asyncio.wait_for() is too buggy. use util.wait_for2() instead
wasted some time because asyncio.wait_for() was suppressing cancellations. [0][1][2] deja vu... [3] Looks like this is finally getting fixed in cpython 3.12 [4] So far away... In attempt to avoid encountering this again, let's try using asyncio.timeout in 3.11, which is how upstream reimplemented wait_for in 3.12 [4], and aiorpcx.timeout_after in 3.8-3.10. [0] https://github.com/python/cpython/issues/86296 [1] https://bugs.python.org/issue42130 [2] https://bugs.python.org/issue45098 [3] https://github.com/kyuupichan/aiorpcX/issues/44 [4] https://github.com/python/cpython/pull/98518
This commit is contained in:
@@ -166,7 +166,7 @@ class NotificationSession(RPCSession):
|
|||||||
try:
|
try:
|
||||||
# note: RPCSession.send_request raises TaskTimeout in case of a timeout.
|
# note: RPCSession.send_request raises TaskTimeout in case of a timeout.
|
||||||
# TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
|
# TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
|
||||||
response = await asyncio.wait_for(
|
response = await util.wait_for2(
|
||||||
super().send_request(*args, **kwargs),
|
super().send_request(*args, **kwargs),
|
||||||
timeout)
|
timeout)
|
||||||
except (TaskTimeout, asyncio.TimeoutError) as e:
|
except (TaskTimeout, asyncio.TimeoutError) as e:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import functools
|
|||||||
|
|
||||||
import aiorpcx
|
import aiorpcx
|
||||||
from aiorpcx import ignore_after
|
from aiorpcx import ignore_after
|
||||||
|
from async_timeout import timeout
|
||||||
|
|
||||||
from .crypto import sha256, sha256d
|
from .crypto import sha256, sha256d
|
||||||
from . import bitcoin, util
|
from . import bitcoin, util
|
||||||
@@ -331,7 +332,7 @@ class Peer(Logger):
|
|||||||
|
|
||||||
async def wait_for_message(self, expected_name: str, channel_id: bytes):
|
async def wait_for_message(self, expected_name: str, channel_id: bytes):
|
||||||
q = self.ordered_message_queues[channel_id]
|
q = self.ordered_message_queues[channel_id]
|
||||||
name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT)
|
name, payload = await util.wait_for2(q.get(), LN_P2P_NETWORK_TIMEOUT)
|
||||||
# raise exceptions for errors, so that the caller sees them
|
# raise exceptions for errors, so that the caller sees them
|
||||||
if (err_bytes := payload.get("error")) is not None:
|
if (err_bytes := payload.get("error")) is not None:
|
||||||
err_text = error_text_bytes_to_safe_str(err_bytes)
|
err_text = error_text_bytes_to_safe_str(err_bytes)
|
||||||
@@ -460,12 +461,12 @@ class Peer(Logger):
|
|||||||
|
|
||||||
async def query_gossip(self):
|
async def query_gossip(self):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
|
await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e
|
raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e
|
||||||
if self.lnworker == self.lnworker.network.lngossip:
|
if self.lnworker == self.lnworker.network.lngossip:
|
||||||
try:
|
try:
|
||||||
ids, complete = await asyncio.wait_for(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
|
ids, complete = await util.wait_for2(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
|
||||||
except asyncio.TimeoutError as e:
|
except asyncio.TimeoutError as e:
|
||||||
raise GracefulDisconnect("query_channel_range timed out") from e
|
raise GracefulDisconnect("query_channel_range timed out") from e
|
||||||
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
|
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
|
||||||
@@ -575,7 +576,7 @@ class Peer(Logger):
|
|||||||
|
|
||||||
async def _message_loop(self):
|
async def _message_loop(self):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.initialize(), LN_P2P_NETWORK_TIMEOUT)
|
await util.wait_for2(self.initialize(), LN_P2P_NETWORK_TIMEOUT)
|
||||||
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
|
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
|
||||||
raise GracefulDisconnect(f'initialize failed: {repr(e)}') from e
|
raise GracefulDisconnect(f'initialize failed: {repr(e)}') from e
|
||||||
async for msg in self.transport.read_messages():
|
async for msg in self.transport.read_messages():
|
||||||
@@ -699,7 +700,7 @@ class Peer(Logger):
|
|||||||
Channel configurations are initialized in this method.
|
Channel configurations are initialized in this method.
|
||||||
"""
|
"""
|
||||||
# will raise if init fails
|
# will raise if init fails
|
||||||
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
|
await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
|
||||||
# trampoline is not yet in features
|
# trampoline is not yet in features
|
||||||
if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey):
|
if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey):
|
||||||
raise Exception('Not a trampoline node: ' + str(self.their_features))
|
raise Exception('Not a trampoline node: ' + str(self.their_features))
|
||||||
|
|||||||
@@ -1071,7 +1071,7 @@ class LNWallet(LNWorker):
|
|||||||
funding_sat=funding_sat,
|
funding_sat=funding_sat,
|
||||||
push_msat=push_sat * 1000,
|
push_msat=push_sat * 1000,
|
||||||
temp_channel_id=os.urandom(32))
|
temp_channel_id=os.urandom(32))
|
||||||
chan, funding_tx = await asyncio.wait_for(coro, LN_P2P_NETWORK_TIMEOUT)
|
chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
|
||||||
util.trigger_callback('channels_updated', self.wallet)
|
util.trigger_callback('channels_updated', self.wallet)
|
||||||
self.wallet.adb.add_transaction(funding_tx) # save tx as local into the wallet
|
self.wallet.adb.add_transaction(funding_tx) # save tx as local into the wallet
|
||||||
self.wallet.sign_transaction(funding_tx, password)
|
self.wallet.sign_transaction(funding_tx, password)
|
||||||
|
|||||||
@@ -811,7 +811,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||||||
# note: using longer timeouts here as DNS can sometimes be slow!
|
# note: using longer timeouts here as DNS can sometimes be slow!
|
||||||
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
|
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(interface.ready, timeout)
|
await util.wait_for2(interface.ready, timeout)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
self.logger.info(f"couldn't launch iface {server} -- {repr(e)}")
|
self.logger.info(f"couldn't launch iface {server} -- {repr(e)}")
|
||||||
await interface.close()
|
await interface.close()
|
||||||
@@ -1401,7 +1401,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||||||
async def get_response(server: ServerAddr):
|
async def get_response(server: ServerAddr):
|
||||||
interface = Interface(network=self, server=server, proxy=self.proxy)
|
interface = Interface(network=self, server=server, proxy=self.proxy)
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(interface.ready, timeout)
|
await util.wait_for2(interface.ready, timeout)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await interface.close()
|
await interface.close()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiorpcx import NetAddress
|
from aiorpcx import NetAddress
|
||||||
|
|
||||||
|
from electrum import util
|
||||||
from electrum.util import log_exceptions, ignore_exceptions
|
from electrum.util import log_exceptions, ignore_exceptions
|
||||||
from electrum.plugin import BasePlugin, hook
|
from electrum.plugin import BasePlugin, hook
|
||||||
from electrum.logging import Logger
|
from electrum.logging import Logger
|
||||||
@@ -173,7 +174,7 @@ class PayServer(Logger, EventListener):
|
|||||||
return ws
|
return ws
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.pending[key].wait(), 1)
|
await util.wait_for2(self.pending[key].wait(), 1)
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# send data on the websocket, to keep it alive
|
# send data on the websocket, to keep it alive
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import time
|
|||||||
|
|
||||||
from electrum.logging import get_logger, configure_logging
|
from electrum.logging import get_logger, configure_logging
|
||||||
from electrum.simple_config import SimpleConfig
|
from electrum.simple_config import SimpleConfig
|
||||||
from electrum import constants
|
from electrum import constants, util
|
||||||
from electrum.daemon import Daemon
|
from electrum.daemon import Daemon
|
||||||
from electrum.wallet import create_new_wallet
|
from electrum.wallet import create_new_wallet
|
||||||
from electrum.util import create_and_start_event_loop, log_exceptions, bfh
|
from electrum.util import create_and_start_event_loop, log_exceptions, bfh
|
||||||
@@ -84,7 +84,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag):
|
|||||||
print(f"worker connecting to {connect_str}")
|
print(f"worker connecting to {connect_str}")
|
||||||
try:
|
try:
|
||||||
peer = await wallet.lnworker.add_peer(connect_str)
|
peer = await wallet.lnworker.add_peer(connect_str)
|
||||||
res = await asyncio.wait_for(peer.initialized, TIMEOUT)
|
res = await util.wait_for2(peer.initialized, TIMEOUT)
|
||||||
if res:
|
if res:
|
||||||
if peer.features & flag == work['features'] & flag:
|
if peer.features & flag == work['features'] & flag:
|
||||||
await results_queue.put(True)
|
await results_queue.put(True)
|
||||||
|
|||||||
@@ -824,8 +824,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
alice_channel, bob_channel = create_test_channels()
|
alice_channel, bob_channel = create_test_channels()
|
||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
async def pay():
|
async def pay():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# prep
|
# prep
|
||||||
_maybe_send_commitment1 = p1.maybe_send_commitment
|
_maybe_send_commitment1 = p1.maybe_send_commitment
|
||||||
_maybe_send_commitment2 = p2.maybe_send_commitment
|
_maybe_send_commitment2 = p2.maybe_send_commitment
|
||||||
@@ -1374,8 +1374,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
w2.enable_htlc_settle = False
|
w2.enable_htlc_settle = False
|
||||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||||
async def pay():
|
async def pay():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# alice sends htlc
|
# alice sends htlc
|
||||||
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
|
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
|
||||||
p1.pay(route=route,
|
p1.pay(route=route,
|
||||||
@@ -1401,8 +1401,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
|
|
||||||
async def action():
|
async def action():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
|
await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
|
||||||
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||||
with self.assertRaises(GracefulDisconnect):
|
with self.assertRaises(GracefulDisconnect):
|
||||||
@@ -1414,8 +1414,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
|
|
||||||
async def action():
|
async def action():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
|
await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
|
||||||
assert alice_channel.is_closed()
|
assert alice_channel.is_closed()
|
||||||
gath.cancel()
|
gath.cancel()
|
||||||
@@ -1447,8 +1447,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
|
|
||||||
async def test():
|
async def test():
|
||||||
async def close():
|
async def close():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# bob closes channel with different shutdown script
|
# bob closes channel with different shutdown script
|
||||||
await p1.close_channel(alice_channel.channel_id)
|
await p1.close_channel(alice_channel.channel_id)
|
||||||
gath.cancel()
|
gath.cancel()
|
||||||
@@ -1477,8 +1477,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
|
|
||||||
async def test():
|
async def test():
|
||||||
async def close():
|
async def close():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
await p1.close_channel(alice_channel.channel_id)
|
await p1.close_channel(alice_channel.channel_id)
|
||||||
gath.cancel()
|
gath.cancel()
|
||||||
|
|
||||||
@@ -1538,8 +1538,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
|
|
||||||
async def send_weird_messages():
|
async def send_weird_messages():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# peer1 sends known message with trailing garbage
|
# peer1 sends known message with trailing garbage
|
||||||
# BOLT-01 says peer2 should ignore trailing garbage
|
# BOLT-01 says peer2 should ignore trailing garbage
|
||||||
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55))
|
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55))
|
||||||
@@ -1570,8 +1570,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
|
|
||||||
async def send_weird_messages():
|
async def send_weird_messages():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# peer1 sends unknown 'even-type' message
|
# peer1 sends unknown 'even-type' message
|
||||||
# BOLT-01 says peer2 should close the connection
|
# BOLT-01 says peer2 should close the connection
|
||||||
raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55))
|
raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55))
|
||||||
@@ -1600,8 +1600,8 @@ class TestPeer(ElectrumTestCase):
|
|||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
|
|
||||||
async def send_weird_messages():
|
async def send_weird_messages():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await util.wait_for2(p1.initialized, 1)
|
||||||
await asyncio.wait_for(p2.initialized, 1)
|
await util.wait_for2(p2.initialized, 1)
|
||||||
# peer1 sends known message with insufficient length for the contents
|
# peer1 sends known message with insufficient length for the contents
|
||||||
# BOLT-01 says peer2 should fail the connection
|
# BOLT-01 says peer2 should fail the connection
|
||||||
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1]
|
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1]
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import binascii
|
|||||||
import os, sys, re, json
|
import os, sys, re, json
|
||||||
from collections import defaultdict, OrderedDict
|
from collections import defaultdict, OrderedDict
|
||||||
from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
|
from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
|
||||||
Sequence, Dict, Generic, TypeVar, List, Iterable, Set)
|
Sequence, Dict, Generic, TypeVar, List, Iterable, Set, Awaitable)
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import decimal
|
import decimal
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@@ -1371,6 +1371,36 @@ aiorpcx.curio._set_task_deadline = _aiorpcx_monkeypatched_set_task_deadline
|
|||||||
aiorpcx.curio._unset_task_deadline = _aiorpcx_monkeypatched_unset_task_deadline
|
aiorpcx.curio._unset_task_deadline = _aiorpcx_monkeypatched_unset_task_deadline
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for2(fut: Awaitable, timeout: Union[int, float, None]):
|
||||||
|
"""Replacement for asyncio.wait_for,
|
||||||
|
due to bugs: https://bugs.python.org/issue42130 and https://github.com/python/cpython/issues/86296 ,
|
||||||
|
which are only fixed in python 3.12+.
|
||||||
|
"""
|
||||||
|
if sys.version_info[:3] >= (3, 12):
|
||||||
|
return await asyncio.wait_for(fut, timeout)
|
||||||
|
else:
|
||||||
|
async with async_timeout(timeout):
|
||||||
|
return await asyncio.ensure_future(fut, loop=get_running_loop())
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(asyncio, 'timeout'): # python 3.11+
|
||||||
|
async_timeout = asyncio.timeout
|
||||||
|
else:
|
||||||
|
class TimeoutAfterAsynciolike(aiorpcx.curio.TimeoutAfter):
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
try:
|
||||||
|
await super().__aexit__(exc_type, exc_value, traceback)
|
||||||
|
except (aiorpcx.TaskTimeout, aiorpcx.UncaughtTimeoutError):
|
||||||
|
raise asyncio.TimeoutError from None
|
||||||
|
except aiorpcx.TimeoutCancellationError:
|
||||||
|
raise asyncio.CancelledError from None
|
||||||
|
|
||||||
|
def async_timeout(delay: Union[int, float, None]):
|
||||||
|
if delay is None:
|
||||||
|
return nullcontext()
|
||||||
|
return TimeoutAfterAsynciolike(delay)
|
||||||
|
|
||||||
|
|
||||||
class NetworkJobOnDefaultServer(Logger, ABC):
|
class NetworkJobOnDefaultServer(Logger, ABC):
|
||||||
"""An abstract base class for a job that runs on the main network
|
"""An abstract base class for a job that runs on the main network
|
||||||
interface. Every time the main interface changes, the job is
|
interface. Every time the main interface changes, the job is
|
||||||
|
|||||||
Reference in New Issue
Block a user