From 27599ac53798504d7c5393795cac5544520e4364 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 6 Jun 2025 16:42:15 +0000 Subject: [PATCH] interface: small clean-up. intro ChainResolutionMode. - type hints - minor API changes - no functional changes --- electrum/interface.py | 111 +++++++++++++++++++++++++++++------------ electrum/lnverifier.py | 2 +- electrum/network.py | 8 ++- electrum/verifier.py | 2 +- tests/test_network.py | 75 +++++++++++++++------------- 5 files changed, 127 insertions(+), 71 deletions(-) diff --git a/electrum/interface.py b/electrum/interface.py index 6ecf7850b..faa97880d 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -38,6 +38,7 @@ import logging import hashlib import functools import random +import enum import aiorpcx from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer @@ -132,6 +133,14 @@ def assert_list_or_tuple(val: Any) -> None: raise RequestCorrupted(f'{val!r} should be a list or tuple') +class ChainResolutionMode(enum.Enum): + CATCHUP = enum.auto() + BACKWARD = enum.auto() + BINARY = enum.auto() + FORK = enum.auto() + NO_FORK = enum.auto() + + class NotificationSession(RPCSession): def __init__(self, *args, interface: 'Interface', **kwargs): @@ -510,7 +519,7 @@ class Interface(Logger): # Note that these values are updated before they are verified. # Especially during initial header sync, verification can take a long time. # Failing verification will get the interface closed. - self.tip_header = None + self.tip_header = None # type: Optional[dict] self.tip = 0 self.fee_estimates_eta = {} # type: Dict[int, int] @@ -543,13 +552,13 @@ class Interface(Logger): def __str__(self): return f"" - async def is_server_ca_signed(self, ca_ssl_context): + async def is_server_ca_signed(self, ca_ssl_context: ssl.SSLContext) -> bool: """Given a CA enforcing SSL context, returns True if the connection can be established. Returns False if the server has a self-signed certificate but otherwise is okay. Any other failures raise. """ try: - await self.open_session(ca_ssl_context, exit_early=True) + await self.open_session(ssl_context=ca_ssl_context, exit_early=True) except ConnectError as e: cause = e.__cause__ if (isinstance(cause, ssl.SSLCertVerificationError) @@ -562,7 +571,7 @@ class Interface(Logger): # Good. We will use this server as CA-signed. return True - async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context): + async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context: ssl.SSLContext) -> None: ca_signed = await self.is_server_ca_signed(ca_ssl_context) if ca_signed: if self._get_expected_fingerprint(): @@ -599,10 +608,10 @@ class Interface(Logger): self.logger.info(f"certificate has expired: {e}") os.unlink(self.cert_path) # delete pinned cert only in this case return False - self._verify_certificate_fingerprint(bytearray(b)) + self._verify_certificate_fingerprint(bytes(b)) return True - async def _get_ssl_context(self): + async def _get_ssl_context(self) -> Optional[ssl.SSLContext]: if self.protocol != 's': # using plaintext TCP return None @@ -658,7 +667,7 @@ class Interface(Logger): self.logger.info(f'disconnecting due to: {repr(e)}') return try: - await self.open_session(ssl_context) + await self.open_session(ssl_context=ssl_context) except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e: # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues) if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError) @@ -731,8 +740,9 @@ class Interface(Logger): def _get_expected_fingerprint(self) -> Optional[str]: if self.is_main_server(): return self.network.config.NETWORK_SERVERFINGERPRINT + return None - def _verify_certificate_fingerprint(self, certificate): + def _verify_certificate_fingerprint(self, certificate: bytes) -> None: expected_fingerprint = self._get_expected_fingerprint() if not expected_fingerprint: return @@ -743,21 +753,27 @@ class Interface(Logger): raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch') self.logger.info("cert fingerprint verification passed") - async def get_block_header(self, height, assert_mode): + async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict: if not is_non_negative_integer(height): raise Exception(f"{repr(height)} is not a block height") - self.logger.info(f'requesting block header {height} in mode {assert_mode}') + self.logger.info(f'requesting block header {height} in {mode=}') # use lower timeout as we usually have network.bhi_lock here timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent) res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout) return blockchain.deserialize_header(bytes.fromhex(res), height) - async def request_chunk(self, height: int, tip=None, *, can_return_early=False): + async def request_chunk( + self, + height: int, + *, + tip: Optional[int] = None, + can_return_early: bool = False, + ) -> Optional[Tuple[bool, int]]: if not is_non_negative_integer(height): raise Exception(f"{repr(height)} is not a block height") index = height // 2016 if can_return_early and index in self._requested_chunks: - return + return None self.logger.info(f"requesting chunk from height {height}") size = 2016 if tip is not None: @@ -790,12 +806,17 @@ class Interface(Logger): return (self.network.interface == self or self.network.interface is None and self.network.default_server == self.server) - async def open_session(self, sslc, exit_early=False): + async def open_session( + self, + *, + ssl_context: Optional[ssl.SSLContext], + exit_early: bool = False, + ): session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface) async with _RSClient( session_factory=session_factory, host=self.host, port=self.port, - ssl=sslc, + ssl=ssl_context, proxy=self.proxy, transport=PaddedRSTransport, ) as session: @@ -918,20 +939,25 @@ class Interface(Logger): if self.blockchain.height() >= height and self.blockchain.check_header(header): # another interface amended the blockchain return False - _, height = await self.step(height, header) + _, height = await self.step(height, header=header) # in the simple case, height == self.tip+1 if height <= self.tip: await self.sync_until(height) return True - async def sync_until(self, height, next_height=None): + async def sync_until( + self, + height: int, + *, + next_height: Optional[int] = None, + ) -> Tuple[ChainResolutionMode, int]: if next_height is None: next_height = self.tip - last = None + last = None # type: Optional[ChainResolutionMode] while last is None or height <= next_height: prev_last, prev_height = last, height - if next_height > height + 10: - could_connect, num_headers = await self.request_chunk(height, next_height) + if next_height > height + 10: # TODO make smarter. the protocol allows asking for n headers + could_connect, num_headers = await self.request_chunk(height, tip=next_height) if not could_connect: if height <= constants.net.max_checkpoint(): raise GracefulDisconnect('server chain conflicts with checkpoints or genesis') @@ -941,16 +967,21 @@ class Interface(Logger): util.trigger_callback('network_updated') height = (height // 2016 * 2016) + num_headers assert height <= next_height+1, (height, self.tip) - last = 'catchup' + last = ChainResolutionMode.CATCHUP else: last, height = await self.step(height) assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until' return last, height - async def step(self, height, header=None): + async def step( + self, + height: int, + *, + header: Optional[dict] = None, # at 'height' + ) -> Tuple[ChainResolutionMode, int]: assert 0 <= height <= self.tip, (height, self.tip) if header is None: - header = await self.get_block_header(height, 'catchup') + header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) if chain: @@ -959,12 +990,12 @@ class Interface(Logger): # we might know the blockhash (enough for check_header) but # not have the header itself. e.g. regtest chain with only genesis. # this situation resolves itself on the next block - return 'catchup', height+1 + return ChainResolutionMode.CATCHUP, height+1 can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) if not can_connect: self.logger.info(f"can't connect new block: {height=}") - height, header, bad, bad_header = await self._search_headers_backwards(height, header) + height, header, bad, bad_header = await self._search_headers_backwards(height, header=header) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) assert chain or can_connect @@ -974,12 +1005,18 @@ class Interface(Logger): if isinstance(can_connect, Blockchain): # not when mocking self.blockchain = can_connect self.blockchain.save_header(header) - return 'catchup', height + return ChainResolutionMode.CATCHUP, height good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain) return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header) - async def _search_headers_binary(self, height, bad, bad_header, chain): + async def _search_headers_binary( + self, + height: int, + bad: int, + bad_header: dict, + chain: Optional[Blockchain], + ) -> Tuple[int, int, dict]: assert bad == bad_header['block_height'] _assert_header_does_not_check_against_any_chain(bad_header) @@ -989,7 +1026,7 @@ class Interface(Logger): assert good < bad, (good, bad) height = (good + bad) // 2 self.logger.info(f"binary step. good {good}, bad {bad}, height {height}") - header = await self.get_block_header(height, 'binary') + header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) if chain: self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain @@ -1009,7 +1046,12 @@ class Interface(Logger): self.logger.info(f"binary search exited. good {good}, bad {bad}") return good, bad, bad_header - async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header): + async def _resolve_potential_chain_fork_given_forkpoint( + self, + good: int, + bad: int, + bad_header: dict, + ) -> Tuple[ChainResolutionMode, int]: assert good + 1 == bad assert bad == bad_header['block_height'] _assert_header_does_not_check_against_any_chain(bad_header) @@ -1021,7 +1063,7 @@ class Interface(Logger): if bh == good: height = good + 1 self.logger.info(f"catching up from {height}") - return 'no_fork', height + return ChainResolutionMode.NO_FORK, height # this is a new fork we don't yet have height = bad + 1 @@ -1030,16 +1072,21 @@ class Interface(Logger): b = forkfun(bad_header) # type: Blockchain self.blockchain = b assert b.forkpoint == bad - return 'fork', height + return ChainResolutionMode.FORK, height - async def _search_headers_backwards(self, height, header): + async def _search_headers_backwards( + self, + height: int, + *, + header: dict, + ) -> Tuple[int, dict, int, dict]: async def iterate(): nonlocal height, header checkp = False if height <= constants.net.max_checkpoint(): height = constants.net.max_checkpoint() checkp = True - header = await self.get_block_header(height, 'backward') + header = await self.get_block_header(height, mode=ChainResolutionMode.BACKWARD) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) if chain or can_connect: diff --git a/electrum/lnverifier.py b/electrum/lnverifier.py index 7d6b22e46..71d6606aa 100644 --- a/electrum/lnverifier.py +++ b/electrum/lnverifier.py @@ -102,7 +102,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): header = blockchain.read_header(block_height) if header is None: if block_height < constants.net.max_checkpoint(): - await self.taskgroup.spawn(self.network.request_chunk(block_height, None, can_return_early=True)) + await self.taskgroup.spawn(self.network.request_chunk(block_height, can_return_early=True)) continue self.started_verifying_channel.add(short_channel_id) await self.taskgroup.spawn(self.verify_channel(block_height, short_channel_id)) diff --git a/electrum/network.py b/electrum/network.py index 74abd009f..e1b41b8a5 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -1311,7 +1311,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): @best_effort_reliable @catch_server_exceptions - async def request_chunk(self, height: int, tip=None, *, can_return_early=False): + async def request_chunk( + self, + height: int, + *, + tip: Optional[int] = None, + can_return_early: bool = False, + ) -> Optional[Tuple[bool, int]]: if self.interface is None: # handled by best_effort_reliable raise RequestTimedOut() return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early) diff --git a/electrum/verifier.py b/electrum/verifier.py index ffb8d7ad6..f5f96ce3d 100644 --- a/electrum/verifier.py +++ b/electrum/verifier.py @@ -88,7 +88,7 @@ class SPV(NetworkJobOnDefaultServer): if header is None: if tx_height < constants.net.max_checkpoint(): # FIXME these requests are not counted (self._requests_sent += 1) - await self.taskgroup.spawn(self.interface.request_chunk(tx_height, None, can_return_early=True)) + await self.taskgroup.spawn(self.interface.request_chunk(tx_height, can_return_early=True)) continue # request now self.logger.info(f'requested merkle {tx_hash}') diff --git a/tests/test_network.py b/tests/test_network.py index d2ccdf78e..a9fdf2621 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -5,7 +5,7 @@ import unittest from electrum import constants from electrum.simple_config import SimpleConfig from electrum import blockchain -from electrum.interface import Interface, ServerAddr +from electrum.interface import Interface, ServerAddr, ChainResolutionMode from electrum.crypto import sha256 from electrum.util import OldTaskGroup from electrum import util @@ -13,18 +13,21 @@ from electrum import util from . import ElectrumTestCase +CRM = ChainResolutionMode + + class MockNetwork: - def __init__(self): + def __init__(self, config: SimpleConfig): + self.config = config self.asyncio_loop = util.get_asyncio_loop() self.taskgroup = OldTaskGroup() self.proxy = None class MockInterface(Interface): - def __init__(self, config): + def __init__(self, config: SimpleConfig): self.config = config - network = MockNetwork() - network.config = config + network = MockNetwork(config) super().__init__(network=network, server=ServerAddr.from_str('mock-server:50000:t')) self.q = asyncio.Queue() self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0, @@ -32,12 +35,12 @@ class MockInterface(Interface): self.tip = 12 self.blockchain._size = self.tip + 1 - async def get_block_header(self, height, assert_mode): - assert self.q.qsize() > 0, (height, assert_mode) + async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict: + assert self.q.qsize() > 0, (height, mode) item = await self.q.get() print("step with height", height, item) assert item['block_height'] == height, (item['block_height'], height) - assert assert_mode in item['mock'], (assert_mode, item) + assert mode in item['mock'], (mode, item) return item async def run(self): @@ -63,46 +66,46 @@ class TestNetwork(ElectrumTestCase): async def test_fork_noconflict(self): blockchain.blockchains = {} - self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}}) def mock_connect(height): return height == 6 - self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) - self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1,'check':lambda x: True, 'connect': lambda x: False}}) - self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) + self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1,'check':lambda x: True, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) ifa = self.interface res = await ifa.sync_until(8, next_height=7) - self.assertEqual(('fork', 8), res) + self.assertEqual((CRM.FORK, 8), res) self.assertEqual(self.interface.q.qsize(), 0) async def test_fork_conflict(self): blockchain.blockchains = {7: {'check': lambda bad_header: False}} - self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}}) def mock_connect(height): return height == 6 - self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) - self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1,'check':lambda x: True, 'connect': lambda x: False}}) - self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1,'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) + self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1,'check':lambda x: True, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.BINARY:1,'check':lambda x: True, 'connect': lambda x: True}}) ifa = self.interface res = await ifa.sync_until(8, next_height=7) - self.assertEqual(('fork', 8), res) + self.assertEqual((CRM.FORK, 8), res) self.assertEqual(self.interface.q.qsize(), 0) async def test_can_connect_during_backward(self): blockchain.blockchains = {} - self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}}) def mock_connect(height): return height == 2 - self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) - self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) - self.interface.q.put_nowait({'block_height': 3, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 4, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) + self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect, 'fork': self.mock_fork}}) + self.interface.q.put_nowait({'block_height': 3, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}}) ifa = self.interface res = await ifa.sync_until(8, next_height=4) - self.assertEqual(('catchup', 5), res) + self.assertEqual((CRM.CATCHUP, 5), res) self.assertEqual(self.interface.q.qsize(), 0) def mock_fork(self, bad_header): @@ -113,17 +116,17 @@ class TestNetwork(ElectrumTestCase): async def test_chain_false_during_binary(self): blockchain.blockchains = {} - self.interface.q.put_nowait({'block_height': 8, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: False}}) + self.interface.q.put_nowait({'block_height': 8, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: False}}) mock_connect = lambda height: height == 3 - self.interface.q.put_nowait({'block_height': 7, 'mock': {'backward':1, 'check': lambda x: False, 'connect': mock_connect}}) - self.interface.q.put_nowait({'block_height': 2, 'mock': {'backward':1, 'check': lambda x: True, 'connect': mock_connect}}) - self.interface.q.put_nowait({'block_height': 4, 'mock': {'binary':1, 'check': lambda x: False, 'fork': self.mock_fork, 'connect': mock_connect}}) - self.interface.q.put_nowait({'block_height': 3, 'mock': {'binary':1, 'check': lambda x: True, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 5, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}}) - self.interface.q.put_nowait({'block_height': 6, 'mock': {'catchup':1, 'check': lambda x: False, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 7, 'mock': {CRM.BACKWARD:1, 'check': lambda x: False, 'connect': mock_connect}}) + self.interface.q.put_nowait({'block_height': 2, 'mock': {CRM.BACKWARD:1, 'check': lambda x: True, 'connect': mock_connect}}) + self.interface.q.put_nowait({'block_height': 4, 'mock': {CRM.BINARY:1, 'check': lambda x: False, 'fork': self.mock_fork, 'connect': mock_connect}}) + self.interface.q.put_nowait({'block_height': 3, 'mock': {CRM.BINARY:1, 'check': lambda x: True, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 5, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}}) + self.interface.q.put_nowait({'block_height': 6, 'mock': {CRM.CATCHUP:1, 'check': lambda x: False, 'connect': lambda x: True}}) ifa = self.interface res = await ifa.sync_until(8, next_height=6) - self.assertEqual(('catchup', 7), res) + self.assertEqual((CRM.CATCHUP, 7), res) self.assertEqual(self.interface.q.qsize(), 0)