1
0

interface: small clean-up. intro ChainResolutionMode.

- type hints
- minor API changes
- no functional changes
This commit is contained in:
SomberNight
2025-06-06 16:42:15 +00:00
parent 2b5147eb4d
commit 27599ac537
5 changed files with 127 additions and 71 deletions

View File

@@ -38,6 +38,7 @@ import logging
import hashlib
import functools
import random
import enum
import aiorpcx
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
@@ -132,6 +133,14 @@ def assert_list_or_tuple(val: Any) -> None:
raise RequestCorrupted(f'{val!r} should be a list or tuple')
class ChainResolutionMode(enum.Enum):
CATCHUP = enum.auto()
BACKWARD = enum.auto()
BINARY = enum.auto()
FORK = enum.auto()
NO_FORK = enum.auto()
class NotificationSession(RPCSession):
def __init__(self, *args, interface: 'Interface', **kwargs):
@@ -510,7 +519,7 @@ class Interface(Logger):
# Note that these values are updated before they are verified.
# Especially during initial header sync, verification can take a long time.
# Failing verification will get the interface closed.
self.tip_header = None
self.tip_header = None # type: Optional[dict]
self.tip = 0
self.fee_estimates_eta = {} # type: Dict[int, int]
@@ -543,13 +552,13 @@ class Interface(Logger):
def __str__(self):
return f"<Interface {self.diagnostic_name()}>"
async def is_server_ca_signed(self, ca_ssl_context):
async def is_server_ca_signed(self, ca_ssl_context: ssl.SSLContext) -> bool:
"""Given a CA enforcing SSL context, returns True if the connection
can be established. Returns False if the server has a self-signed
certificate but otherwise is okay. Any other failures raise.
"""
try:
await self.open_session(ca_ssl_context, exit_early=True)
await self.open_session(ssl_context=ca_ssl_context, exit_early=True)
except ConnectError as e:
cause = e.__cause__
if (isinstance(cause, ssl.SSLCertVerificationError)
@@ -562,7 +571,7 @@ class Interface(Logger):
# Good. We will use this server as CA-signed.
return True
async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context: ssl.SSLContext) -> None:
ca_signed = await self.is_server_ca_signed(ca_ssl_context)
if ca_signed:
if self._get_expected_fingerprint():
@@ -599,10 +608,10 @@ class Interface(Logger):
self.logger.info(f"certificate has expired: {e}")
os.unlink(self.cert_path) # delete pinned cert only in this case
return False
self._verify_certificate_fingerprint(bytearray(b))
self._verify_certificate_fingerprint(bytes(b))
return True
async def _get_ssl_context(self):
async def _get_ssl_context(self) -> Optional[ssl.SSLContext]:
if self.protocol != 's':
# using plaintext TCP
return None
@@ -658,7 +667,7 @@ class Interface(Logger):
self.logger.info(f'disconnecting due to: {repr(e)}')
return
try:
await self.open_session(ssl_context)
await self.open_session(ssl_context=ssl_context)
except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
# make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
@@ -731,8 +740,9 @@ class Interface(Logger):
def _get_expected_fingerprint(self) -> Optional[str]:
if self.is_main_server():
return self.network.config.NETWORK_SERVERFINGERPRINT
return None
def _verify_certificate_fingerprint(self, certificate):
def _verify_certificate_fingerprint(self, certificate: bytes) -> None:
expected_fingerprint = self._get_expected_fingerprint()
if not expected_fingerprint:
return
@@ -743,21 +753,27 @@ class Interface(Logger):
raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
self.logger.info("cert fingerprint verification passed")
async def get_block_header(self, height, assert_mode):
async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
if not is_non_negative_integer(height):
raise Exception(f"{repr(height)} is not a block height")
self.logger.info(f'requesting block header {height} in mode {assert_mode}')
self.logger.info(f'requesting block header {height} in {mode=}')
# use lower timeout as we usually have network.bhi_lock here
timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
return blockchain.deserialize_header(bytes.fromhex(res), height)
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
async def request_chunk(
self,
height: int,
*,
tip: Optional[int] = None,
can_return_early: bool = False,
) -> Optional[Tuple[bool, int]]:
if not is_non_negative_integer(height):
raise Exception(f"{repr(height)} is not a block height")
index = height // 2016
if can_return_early and index in self._requested_chunks:
return
return None
self.logger.info(f"requesting chunk from height {height}")
size = 2016
if tip is not None:
@@ -790,12 +806,17 @@ class Interface(Logger):
return (self.network.interface == self or
self.network.interface is None and self.network.default_server == self.server)
async def open_session(self, sslc, exit_early=False):
async def open_session(
self,
*,
ssl_context: Optional[ssl.SSLContext],
exit_early: bool = False,
):
session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
async with _RSClient(
session_factory=session_factory,
host=self.host, port=self.port,
ssl=sslc,
ssl=ssl_context,
proxy=self.proxy,
transport=PaddedRSTransport,
) as session:
@@ -918,20 +939,25 @@ class Interface(Logger):
if self.blockchain.height() >= height and self.blockchain.check_header(header):
# another interface amended the blockchain
return False
_, height = await self.step(height, header)
_, height = await self.step(height, header=header)
# in the simple case, height == self.tip+1
if height <= self.tip:
await self.sync_until(height)
return True
async def sync_until(self, height, next_height=None):
async def sync_until(
self,
height: int,
*,
next_height: Optional[int] = None,
) -> Tuple[ChainResolutionMode, int]:
if next_height is None:
next_height = self.tip
last = None
last = None # type: Optional[ChainResolutionMode]
while last is None or height <= next_height:
prev_last, prev_height = last, height
if next_height > height + 10:
could_connect, num_headers = await self.request_chunk(height, next_height)
if next_height > height + 10: # TODO make smarter. the protocol allows asking for n headers
could_connect, num_headers = await self.request_chunk(height, tip=next_height)
if not could_connect:
if height <= constants.net.max_checkpoint():
raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
@@ -941,16 +967,21 @@ class Interface(Logger):
util.trigger_callback('network_updated')
height = (height // 2016 * 2016) + num_headers
assert height <= next_height+1, (height, self.tip)
last = 'catchup'
last = ChainResolutionMode.CATCHUP
else:
last, height = await self.step(height)
assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
return last, height
async def step(self, height, header=None):
async def step(
self,
height: int,
*,
header: Optional[dict] = None, # at 'height'
) -> Tuple[ChainResolutionMode, int]:
assert 0 <= height <= self.tip, (height, self.tip)
if header is None:
header = await self.get_block_header(height, 'catchup')
header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
if chain:
@@ -959,12 +990,12 @@ class Interface(Logger):
# we might know the blockhash (enough for check_header) but
# not have the header itself. e.g. regtest chain with only genesis.
# this situation resolves itself on the next block
return 'catchup', height+1
return ChainResolutionMode.CATCHUP, height+1
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
if not can_connect:
self.logger.info(f"can't connect new block: {height=}")
height, header, bad, bad_header = await self._search_headers_backwards(height, header)
height, header, bad, bad_header = await self._search_headers_backwards(height, header=header)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
assert chain or can_connect
@@ -974,12 +1005,18 @@ class Interface(Logger):
if isinstance(can_connect, Blockchain): # not when mocking
self.blockchain = can_connect
self.blockchain.save_header(header)
return 'catchup', height
return ChainResolutionMode.CATCHUP, height
good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
async def _search_headers_binary(self, height, bad, bad_header, chain):
async def _search_headers_binary(
self,
height: int,
bad: int,
bad_header: dict,
chain: Optional[Blockchain],
) -> Tuple[int, int, dict]:
assert bad == bad_header['block_height']
_assert_header_does_not_check_against_any_chain(bad_header)
@@ -989,7 +1026,7 @@ class Interface(Logger):
assert good < bad, (good, bad)
height = (good + bad) // 2
self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
header = await self.get_block_header(height, 'binary')
header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
if chain:
self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
@@ -1009,7 +1046,12 @@ class Interface(Logger):
self.logger.info(f"binary search exited. good {good}, bad {bad}")
return good, bad, bad_header
async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
async def _resolve_potential_chain_fork_given_forkpoint(
self,
good: int,
bad: int,
bad_header: dict,
) -> Tuple[ChainResolutionMode, int]:
assert good + 1 == bad
assert bad == bad_header['block_height']
_assert_header_does_not_check_against_any_chain(bad_header)
@@ -1021,7 +1063,7 @@ class Interface(Logger):
if bh == good:
height = good + 1
self.logger.info(f"catching up from {height}")
return 'no_fork', height
return ChainResolutionMode.NO_FORK, height
# this is a new fork we don't yet have
height = bad + 1
@@ -1030,16 +1072,21 @@ class Interface(Logger):
b = forkfun(bad_header) # type: Blockchain
self.blockchain = b
assert b.forkpoint == bad
return 'fork', height
return ChainResolutionMode.FORK, height
async def _search_headers_backwards(self, height, header):
async def _search_headers_backwards(
self,
height: int,
*,
header: dict,
) -> Tuple[int, dict, int, dict]:
async def iterate():
nonlocal height, header
checkp = False
if height <= constants.net.max_checkpoint():
height = constants.net.max_checkpoint()
checkp = True
header = await self.get_block_header(height, 'backward')
header = await self.get_block_header(height, mode=ChainResolutionMode.BACKWARD)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
if chain or can_connect:

View File

@@ -102,7 +102,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
header = blockchain.read_header(block_height)
if header is None:
if block_height < constants.net.max_checkpoint():
await self.taskgroup.spawn(self.network.request_chunk(block_height, None, can_return_early=True))
await self.taskgroup.spawn(self.network.request_chunk(block_height, can_return_early=True))
continue
self.started_verifying_channel.add(short_channel_id)
await self.taskgroup.spawn(self.verify_channel(block_height, short_channel_id))

View File

@@ -1311,7 +1311,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
@best_effort_reliable
@catch_server_exceptions
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
async def request_chunk(
self,
height: int,
*,
tip: Optional[int] = None,
can_return_early: bool = False,
) -> Optional[Tuple[bool, int]]:
if self.interface is None: # handled by best_effort_reliable
raise RequestTimedOut()
return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early)

View File

@@ -88,7 +88,7 @@ class SPV(NetworkJobOnDefaultServer):
if header is None:
if tx_height < constants.net.max_checkpoint():
# FIXME these requests are not counted (self._requests_sent += 1)
await self.taskgroup.spawn(self.interface.request_chunk(tx_height, None, can_return_early=True))
await self.taskgroup.spawn(self.interface.request_chunk(tx_height, can_return_early=True))
continue
# request now
self.logger.info(f'requested merkle {tx_hash}')

View File

@@ -5,7 +5,7 @@ import unittest
from electrum import constants
from electrum.simple_config import SimpleConfig
from electrum import blockchain
from electrum.interface import Interface, ServerAddr
from electrum.interface import Interface, ServerAddr, ChainResolutionMode
from electrum.crypto import sha256
from electrum.util import OldTaskGroup
from electrum import util
@@ -13,18 +13,21 @@ from electrum import util
from . import ElectrumTestCase
CRM = ChainResolutionMode
class MockNetwork:
def __init__(self):
def __init__(self, config: SimpleConfig):
self.config = config
self.asyncio_loop = util.get_asyncio_loop()
self.taskgroup = OldTaskGroup()
self.proxy = None
class MockInterface(Interface):
def __init__(self, config):
def __init__(self, config: SimpleConfig):
self.config = config
network = MockNetwork()
network.config = config
network = MockNetwork(config)
super().__init__(network=network, server=ServerAddr.from_str('mock-server:50000:t'))
self.q = asyncio.Queue()
self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0,
@@ -32,12 +35,12 @@ class MockInterface(Interface):
self.tip = 12
self.blockchain._size = self.tip + 1
async def get_block_header(self, height, assert_mode):
assert self.q.qsize() > 0, (height, assert_mode)
async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
assert self.q.qsize() > 0, (height, mode)
item = await self.q.get()
print("step with height", height, item)
assert item['block_height'] == height, (item['block_height'], height)
assert assert_mode in item['mock'], (assert_mode, item)
assert mode in item['mock'], (mode, item)
return item
async def run(self):
@@ -63,46 +66,46 @@ class TestNetwork(ElectrumTestCase):
async def test_fork_noconflict(self):
blockchain.blockchains = {}
self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}})
def mock_connect(height):
return height == 6
self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1,'check':lambda x: True, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1,'check':lambda x: True, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
ifa = self.interface
res = await ifa.sync_until(8, next_height=7)
self.assertEqual(('fork', 8), res)
self.assertEqual((CRM.FORK, 8), res)
self.assertEqual(self.interface.q.qsize(), 0)
async def test_fork_conflict(self):
blockchain.blockchains = {7: {'check': lambda bad_header: False}}
self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}})
def mock_connect(height):
return height == 6
self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1,'check':lambda x: True, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1,'check':lambda x: True, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}})
ifa = self.interface
res = await ifa.sync_until(8, next_height=7)
self.assertEqual(('fork', 8), res)
self.assertEqual((CRM.FORK, 8), res)
self.assertEqual(self.interface.q.qsize(), 0)
async def test_can_connect_during_backward(self):
blockchain.blockchains = {}
self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}})
def mock_connect(height):
return height == 2
self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 3, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}})
self.interface.q.put_nowait({'block_height': 3, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}})
ifa = self.interface
res = await ifa.sync_until(8, next_height=4)
self.assertEqual(('catchup', 5), res)
self.assertEqual((CRM.CATCHUP, 5), res)
self.assertEqual(self.interface.q.qsize(), 0)
def mock_fork(self, bad_header):
@@ -113,17 +116,17 @@ class TestNetwork(ElectrumTestCase):
async def test_chain_false_during_binary(self):
blockchain.blockchains = {}
self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}})
self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}})
mock_connect = lambda height: height == 3
self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1, 'check': lambda x: True, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1, 'check': lambda x: False, 'fork': self.mock_fork, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 3, 'mock': {'binary':1, 'check': lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1, 'check': lambda x: True, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1, 'check': lambda x: False, 'fork': self.mock_fork, 'connect': mock_connect}})
self.interface.q.put_nowait({'block_height': 3, 'mock': {CRM.BINARY:1, 'check': lambda x: True, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}})
self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}})
ifa = self.interface
res = await ifa.sync_until(8, next_height=6)
self.assertEqual(('catchup', 7), res)
self.assertEqual((CRM.CATCHUP, 7), res)
self.assertEqual(self.interface.q.qsize(), 0)