1
0

Merge pull request #9926 from SomberNight/202506_iface_headers3

interface: faster chain resolution: add headers_cache
This commit is contained in:
ghost43
2025-06-09 15:37:41 +00:00
committed by GitHub
3 changed files with 80 additions and 26 deletions

View File

@@ -445,7 +445,7 @@ class Blockchain(Logger):
raise FileNotFoundError('Cannot find headers file but headers_dir is there. Should be at {}'.format(path)) raise FileNotFoundError('Cannot find headers file but headers_dir is there. Should be at {}'.format(path))
@with_lock @with_lock
def write(self, data: bytes, offset: int, truncate: bool=True) -> None: def write(self, data: bytes, offset: int, truncate: bool = True, *, fsync: bool = True) -> None:
filename = self.path() filename = self.path()
self.assert_headers_file_available(filename) self.assert_headers_file_available(filename)
with open(filename, 'rb+') as f: with open(filename, 'rb+') as f:
@@ -454,8 +454,9 @@ class Blockchain(Logger):
f.truncate() f.truncate()
f.seek(offset) f.seek(offset)
f.write(data) f.write(data)
f.flush() if fsync:
os.fsync(f.fileno()) f.flush()
os.fsync(f.fileno())
self.update_size() self.update_size()
@with_lock @with_lock
@@ -465,7 +466,8 @@ class Blockchain(Logger):
# headers are only _appended_ to the end: # headers are only _appended_ to the end:
assert delta == self.size(), (delta, self.size()) assert delta == self.size(), (delta, self.size())
assert len(data) == HEADER_SIZE assert len(data) == HEADER_SIZE
self.write(data, delta*HEADER_SIZE) # note: we don't fsync, to improve perf. losing headers at end of file is ok.
self.write(data, delta*HEADER_SIZE, fsync=False)
self.swap_with_parent() self.swap_with_parent()
@with_lock @with_lock

View File

@@ -525,6 +525,8 @@ class Interface(Logger):
self.tip_header = None # type: Optional[dict] self.tip_header = None # type: Optional[dict]
self.tip = 0 self.tip = 0
self._headers_cache = {} # type: Dict[int, bytes]
self.fee_estimates_eta = {} # type: Dict[int, int] self.fee_estimates_eta = {} # type: Dict[int, int]
# Dump network messages (only for this interface). Set at runtime from the console. # Dump network messages (only for this interface). Set at runtime from the console.
@@ -756,16 +758,41 @@ class Interface(Logger):
raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch') raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
self.logger.info("cert fingerprint verification passed") self.logger.info("cert fingerprint verification passed")
async def _maybe_warm_headers_cache(self, *, from_height: int, to_height: int, mode: ChainResolutionMode) -> None:
"""Populate header cache for block heights in range [from_height, to_height]."""
assert from_height <= to_height, (from_height, to_height)
assert to_height - from_height < MAX_NUM_HEADERS_PER_REQUEST
if all(height in self._headers_cache for height in range(from_height, to_height+1)):
# cache already has all requested headers
return
# use lower timeout as we usually have network.bhi_lock here
timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
count = to_height - from_height + 1
headers = await self.get_block_headers(start_height=from_height, count=count, timeout=timeout, mode=mode)
for idx, raw_header in enumerate(headers):
header_height = from_height + idx
self._headers_cache[header_height] = raw_header
async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict: async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
if not is_non_negative_integer(height): if not is_non_negative_integer(height):
raise Exception(f"{repr(height)} is not a block height") raise Exception(f"{repr(height)} is not a block height")
self.logger.info(f'requesting block header {height} in {mode=}') #self.logger.debug(f'get_block_header() {height} in {mode=}')
# use lower timeout as we usually have network.bhi_lock here # use lower timeout as we usually have network.bhi_lock here
timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent) timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
if raw_header := self._headers_cache.get(height):
return blockchain.deserialize_header(raw_header, height)
self.logger.info(f'requesting block header {height} in {mode=}')
res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout) res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
return blockchain.deserialize_header(bytes.fromhex(res), height) return blockchain.deserialize_header(bytes.fromhex(res), height)
async def get_block_headers(self, *, start_height: int, count: int) -> Sequence[bytes]: async def get_block_headers(
self,
*,
start_height: int,
count: int,
timeout=None,
mode: Optional[ChainResolutionMode] = None,
) -> Sequence[bytes]:
"""Request a number of consecutive block headers, starting at `start_height`. """Request a number of consecutive block headers, starting at `start_height`.
`count` is the num of requested headers, BUT note the server might return fewer than this `count` is the num of requested headers, BUT note the server might return fewer than this
(if range would extend beyond its tip). (if range would extend beyond its tip).
@@ -775,8 +802,11 @@ class Interface(Logger):
raise Exception(f"{repr(start_height)} is not a block height") raise Exception(f"{repr(start_height)} is not a block height")
if not is_non_negative_integer(count) or not (0 < count <= MAX_NUM_HEADERS_PER_REQUEST): if not is_non_negative_integer(count) or not (0 < count <= MAX_NUM_HEADERS_PER_REQUEST):
raise Exception(f"{repr(count)} not an int in range ]0, {MAX_NUM_HEADERS_PER_REQUEST}]") raise Exception(f"{repr(count)} not an int in range ]0, {MAX_NUM_HEADERS_PER_REQUEST}]")
self.logger.info(f'requesting block headers: [{start_height}, {start_height+count-1}], {count=}') self.logger.info(
res = await self.session.send_request('blockchain.block.headers', [start_height, count]) f"requesting block headers: [{start_height}, {start_height+count-1}], {count=}"
+ (f" (in {mode=})" if mode is not None else "")
)
res = await self.session.send_request('blockchain.block.headers', [start_height, count], timeout=timeout)
# check response # check response
assert_dict_contains_field(res, field_name='count') assert_dict_contains_field(res, field_name='count')
assert_dict_contains_field(res, field_name='hex') assert_dict_contains_field(res, field_name='hex')
@@ -938,17 +968,23 @@ class Interface(Logger):
item = await header_queue.get() item = await header_queue.get()
raw_header = item[0] raw_header = item[0]
height = raw_header['height'] height = raw_header['height']
header = blockchain.deserialize_header(bfh(raw_header['hex']), height) header_bytes = bfh(raw_header['hex'])
self.tip_header = header header_dict = blockchain.deserialize_header(header_bytes, height)
self.tip_header = header_dict
self.tip = height self.tip = height
if self.tip < constants.net.max_checkpoint(): if self.tip < constants.net.max_checkpoint():
raise GracefulDisconnect( raise GracefulDisconnect(
f"server tip below max checkpoint. ({self.tip} < {constants.net.max_checkpoint()})") f"server tip below max checkpoint. ({self.tip} < {constants.net.max_checkpoint()})")
self._mark_ready() self._mark_ready()
blockchain_updated = await self._process_header_at_tip() self._headers_cache.clear() # tip changed, so assume anything could have happened with chain
self._headers_cache[height] = header_bytes
try:
blockchain_updated = await self._process_header_at_tip()
finally:
self._headers_cache.clear() # to reduce memory usage
# header processing done # header processing done
if self.is_main_server(): if self.is_main_server() or blockchain_updated:
self.logger.info(f"new chain tip on main interface. {height=}") self.logger.info(f"new chain tip. {height=}")
if blockchain_updated: if blockchain_updated:
util.trigger_callback('blockchain_updated') util.trigger_callback('blockchain_updated')
util.trigger_callback('network_updated') util.trigger_callback('network_updated')
@@ -966,36 +1002,40 @@ class Interface(Logger):
if self.blockchain.height() >= height and self.blockchain.check_header(header): if self.blockchain.height() >= height and self.blockchain.check_header(header):
# another interface amended the blockchain # another interface amended the blockchain
return False return False
_, height = await self.step(height, header=header) await self.sync_until(height)
# in the simple case, height == self.tip+1
if height <= self.tip:
await self.sync_until(height)
return True return True
async def sync_until( async def sync_until(
self, self,
height: int, height: int,
*, *,
next_height: Optional[int] = None, next_height: Optional[int] = None, # sync target. typically the tip, except in unit tests
) -> Tuple[ChainResolutionMode, int]: ) -> Tuple[ChainResolutionMode, int]:
if next_height is None: if next_height is None:
next_height = self.tip next_height = self.tip
last = None # type: Optional[ChainResolutionMode] last = None # type: Optional[ChainResolutionMode]
while last is None or height <= next_height: while last is None or height <= next_height:
prev_last, prev_height = last, height prev_last, prev_height = last, height
if next_height > height + 10: # TODO make smarter. the protocol allows asking for n headers if next_height > height + 144:
# We are far from the tip.
# It is more efficient to process headers in large batches (CPU/disk_usage/logging).
# (but this wastes a little bandwidth, if we are not on a chunk boundary)
# TODO we should request (some) chunks concurrently. would help when we are many chunks behind
could_connect, num_headers = await self.request_chunk(height, tip=next_height) could_connect, num_headers = await self.request_chunk(height, tip=next_height)
if not could_connect: if not could_connect:
if height <= constants.net.max_checkpoint(): if height <= constants.net.max_checkpoint():
raise GracefulDisconnect('server chain conflicts with checkpoints or genesis') raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
last, height = await self.step(height) last, height = await self.step(height)
continue continue
# report progress to gui/etc
util.trigger_callback('blockchain_updated') util.trigger_callback('blockchain_updated')
util.trigger_callback('network_updated') util.trigger_callback('network_updated')
height = (height // CHUNK_SIZE * CHUNK_SIZE) + num_headers height = (height // CHUNK_SIZE * CHUNK_SIZE) + num_headers
assert height <= next_height+1, (height, self.tip) assert height <= next_height+1, (height, self.tip)
last = ChainResolutionMode.CATCHUP last = ChainResolutionMode.CATCHUP
else: else:
# We are close to the tip, so process headers one-by-one.
# (note: due to headers_cache, to save network latency, this can still batch-request headers)
last, height = await self.step(height) last, height = await self.step(height)
assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until' assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
return last, height return last, height
@@ -1003,12 +1043,14 @@ class Interface(Logger):
async def step( async def step(
self, self,
height: int, height: int,
*,
header: Optional[dict] = None, # at 'height'
) -> Tuple[ChainResolutionMode, int]: ) -> Tuple[ChainResolutionMode, int]:
assert 0 <= height <= self.tip, (height, self.tip) assert 0 <= height <= self.tip, (height, self.tip)
if header is None: await self._maybe_warm_headers_cache(
header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP) from_height=height,
to_height=min(self.tip, height+MAX_NUM_HEADERS_PER_REQUEST-1),
mode=ChainResolutionMode.CATCHUP,
)
header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
if chain: if chain:
@@ -1027,7 +1069,6 @@ class Interface(Logger):
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
assert chain or can_connect assert chain or can_connect
if can_connect: if can_connect:
self.logger.info(f"new block: {height=}")
height += 1 height += 1
if isinstance(can_connect, Blockchain): # not when mocking if isinstance(can_connect, Blockchain): # not when mocking
self.blockchain = can_connect self.blockchain = can_connect
@@ -1050,9 +1091,12 @@ class Interface(Logger):
self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
good = height good = height
while True: while True:
assert good < bad, (good, bad) assert 0 <= good < bad, (good, bad)
height = (good + bad) // 2 height = (good + bad) // 2
self.logger.info(f"binary step. good {good}, bad {bad}, height {height}") self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
if bad - good + 1 <= MAX_NUM_HEADERS_PER_REQUEST: # if interval is small, trade some bandwidth for lower latency
await self._maybe_warm_headers_cache(
from_height=good, to_height=bad, mode=ChainResolutionMode.BINARY)
header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY) header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY)
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
if chain: if chain:
@@ -1127,9 +1171,14 @@ class Interface(Logger):
with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values()) with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values())
local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf') local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf')
height = min(local_max + 1, height - 1) height = min(local_max + 1, height - 1)
assert height >= 0
await self._maybe_warm_headers_cache(
from_height=max(0, height-10), to_height=height, mode=ChainResolutionMode.BACKWARD)
while await iterate(): while await iterate():
bad, bad_header = height, header bad, bad_header = height, header
delta = self.tip - height delta = self.tip - height # FIXME why compared to tip? would be easier to cache if delta started at 1
assert delta > 0, delta assert delta > 0, delta
height = self.tip - 2 * delta height = self.tip - 2 * delta

View File

@@ -46,6 +46,9 @@ class MockInterface(Interface):
async def run(self): async def run(self):
return return
async def _maybe_warm_headers_cache(self, *args, **kwargs):
return
class TestNetwork(ElectrumTestCase): class TestNetwork(ElectrumTestCase):