diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py index 7907645b2..c231eab12 100644 --- a/electrum/synchronizer.py +++ b/electrum/synchronizer.py @@ -160,7 +160,18 @@ class Synchronizer(SynchronizerBase): and not self._stale_histories and self.status_queue.empty()) - async def _maybe_request_history_for_addr(self, addr: str) -> List[dict]: + async def _maybe_request_history_for_addr(self, addr: str, *, ann_status: Optional[str]) -> List[dict]: + # First opportunistically try to guess the addr history. Might save us network requests. + old_history = self.adb.db.get_addr_history(addr) + def guess_height(old_height: int) -> int: + if old_height in (0, -1,): + return self.interface.tip # maybe mempool tx got mined just now + return old_height + guessed_history = [(txid, guess_height(old_height)) for (txid, old_height) in old_history] + if history_status(guessed_history) == ann_status: + self.logger.debug(f"managed to guess new history for {addr}. won't call 'blockchain.scripthash.get_history'.") + return [{"height": height, "tx_hash": txid} for (txid, height) in guessed_history] + # request addr history from server sh = address_to_scripthash(addr) self._requests_sent += 1 async with self._network_request_semaphore: @@ -183,7 +194,7 @@ class Synchronizer(SynchronizerBase): self._stale_histories.pop(addr, asyncio.Future()).cancel() finally: self._handling_addr_statuses.discard(addr) - result = await self._maybe_request_history_for_addr(addr) + result = await self._maybe_request_history_for_addr(addr, ann_status=status) hist = list(map(lambda item: (item['tx_hash'], item['height']), result)) # tx_fees tx_fees = [(item['tx_hash'], item.get('fee')) for item in result] diff --git a/tests/test_interface.py b/tests/test_interface.py index 07b13338e..06388a17c 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,5 +1,6 @@ import asyncio import collections +from typing import Optional, Sequence, Iterable import aiorpcx from aiorpcx import RPCError @@ -12,8 +13,13 @@ from electrum.logging import Logger from electrum.simple_config import SimpleConfig from electrum.transaction import Transaction from electrum import constants +from electrum.wallet import Abstract_Wallet +from electrum.blockchain import Blockchain +from electrum.bitcoin import script_to_scripthash +from electrum.synchronizer import history_status from . import ElectrumTestCase +from . import restore_wallet_from_text__for_unittest class TestServerAddr(ElectrumTestCase): @@ -86,6 +92,10 @@ class MockNetwork: pass async def switch_lagging_interface(self): pass + def blockchain(self) -> Blockchain: + return self.interface.blockchain + def get_local_height(self) -> int: + return self.blockchain().height() # regtest chain: @@ -106,11 +116,11 @@ BLOCK_HEADERS = { } _active_server_sessions = set() -def _get_active_server_session() -> 'ServerSession': +def _get_active_server_session() -> 'ToyServerSession': assert 1 == len(_active_server_sessions), len(_active_server_sessions) return list(_active_server_sessions)[0] -class ServerSession(aiorpcx.RPCSession, Logger): +class ToyServerSession(aiorpcx.RPCSession, Logger): def __init__(self, *args, **kwargs): aiorpcx.RPCSession.__init__(self, *args, **kwargs) @@ -120,6 +130,12 @@ class ServerSession(aiorpcx.RPCSession, Logger): self.txs = { "bdae818ad3c1f261317738ae9284159bf54874356f186dbc7afd631dc1527fcb": bfh("020000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff025100ffffffff0200f2052a010000001600140297bde2689a3c79ffe050583b62f86f2d9dae540000000000000000266a24aa21a9ede2f61c3f71d1defd3fa999dfa36953755c690689799962b48bebd836974e8cf90120000000000000000000000000000000000000000000000000000000000000000000000000"), } # type: dict[str, bytes] + self.txid_to_block_height = collections.defaultdict(int) # type: dict[str, int] + self.subbed_headers = False + self.notified_height = None # type: Optional[int] + self.subbed_scripthashes = set() # type: set[str] + self.sh_to_funding_txids = collections.defaultdict(set) # type: dict[str, set[str]] + self.sh_to_spending_txids = collections.defaultdict(set) # type: dict[str, set[str]] self._method_counts = collections.defaultdict(int) # type: dict[str, int] _active_server_sessions.add(self) @@ -136,8 +152,11 @@ class ServerSession(aiorpcx.RPCSession, Logger): 'blockchain.headers.subscribe': self._handle_headers_subscribe, 'blockchain.block.header': self._handle_block_header, 'blockchain.block.headers': self._handle_block_headers, + 'blockchain.scripthash.subscribe': self._handle_scripthash_subscribe, + 'blockchain.scripthash.get_history': self._handle_scripthash_get_history, 'blockchain.transaction.get': self._handle_transaction_get, 'blockchain.transaction.broadcast': self._handle_transaction_broadcast, + 'blockchain.transaction.get_merkle': self._handle_transaction_get_merkle, 'server.ping': self._handle_ping, } handler = handlers.get(request.method) @@ -162,9 +181,13 @@ class ServerSession(aiorpcx.RPCSession, Logger): async def _handle_estimatefee(self, number, mode=None): return 1000 - async def _handle_headers_subscribe(self): + def _get_headersub_result(self): return {'hex': BLOCK_HEADERS[self.cur_height].hex(), 'height': self.cur_height} + async def _handle_headers_subscribe(self): + self.subbed_headers = True + return self._get_headersub_result() + async def _handle_block_header(self, height): return BLOCK_HEADERS[height].hex() @@ -186,10 +209,97 @@ class ServerSession(aiorpcx.RPCSession, Logger): raise RPCError(DAEMON_ERROR, f'daemon error: unknown txid={tx_hash}') return rawtx.hex() - async def _handle_transaction_broadcast(self, raw_tx: str): + async def _handle_transaction_get_merkle(self, tx_hash: str, height: int) -> dict: + # Fake stuff. Client will ignore it due to config.NETWORK_SKIPMERKLECHECK + return { + "merkle": + [ + "713d6c7e6ce7bbea708d61162231eaa8ecb31c4c5dd84f81c20409a90069cb24", + "03dbaec78d4a52fbaf3c7aa5d3fccd9d8654f323940716ddf5ee2e4bda458fde", + "e670224b23f156c27993ac3071940c0ff865b812e21e0a162fe7a005d6e57851", + "369a1619a67c3108a8850118602e3669455c70cdcdb89248b64cc6325575b885", + "4756688678644dcb27d62931f04013254a62aeee5dec139d1aac9f7b1f318112", + "7b97e73abc043836fd890555bfce54757d387943a6860e5450525e8e9ab46be5", + "61505055e8b639b7c64fd58bce6fc5c2378b92e025a02583303f69930091b1c3", + "27a654ff1895385ac14a574a0415d3bbba9ec23a8774f22ec20d53dd0b5386ff", + "5312ed87933075e60a9511857d23d460a085f3b6e9e5e565ad2443d223cfccdc", + "94f60b14a9f106440a197054936e6fb92abbd69d6059b38fdf79b33fc864fca0", + "2d64851151550e8c4d337f335ee28874401d55b358a66f1bafab2c3e9f48773d" + ], + "block_height": height, + "pos": 710, + } + + async def _handle_transaction_broadcast(self, raw_tx: str) -> str: tx = Transaction(raw_tx) - self.txs[tx.txid()] = bfh(raw_tx) - return tx.txid() + txid = tx.txid() + self.txs[txid] = bfh(raw_tx) + touched_sh = await self._process_added_tx(txid=txid) + if touched_sh: + await self._send_notifications(touched_sh=touched_sh) + return txid + + async def _process_added_tx(self, *, txid: str) -> set[str]: + """Returns touched scripthashes.""" + tx = Transaction(self.txs[txid]) + touched_sh = set() + # update sh_to_funding_txids + for txout in tx.outputs(): + sh = script_to_scripthash(txout.scriptpubkey) + self.sh_to_funding_txids[sh].add(txid) + touched_sh.add(sh) + # update sh_to_spending_txids + for txin in tx.inputs(): + if parent_tx_raw := self.txs.get(txin.prevout.txid.hex()): + parent_tx = Transaction(parent_tx_raw) + ptxout = parent_tx.outputs()[txin.prevout.out_idx] + sh = script_to_scripthash(ptxout.scriptpubkey) + self.sh_to_spending_txids[sh].add(txid) + touched_sh.add(sh) + return touched_sh + + async def _handle_scripthash_subscribe(self, sh: str) -> Optional[str]: + self.subbed_scripthashes.add(sh) + hist = self._calc_sh_history(sh) + return history_status(hist) + + async def _handle_scripthash_get_history(self, sh: str) -> Sequence[dict]: + hist_tuples = self._calc_sh_history(sh) + hist_dicts = [{"height": height, "tx_hash": txid} for (txid, height) in hist_tuples] + for hist_dict in hist_dicts: # add "fee" key for mempool txs + if hist_dict["height"] in (0, -1,): + hist_dict["fee"] = 0 + return hist_dicts + + def _calc_sh_history(self, sh: str) -> Sequence[tuple[str, int]]: + txids = self.sh_to_funding_txids[sh] | self.sh_to_spending_txids[sh] + hist = [] + for txid in txids: + bh = self.txid_to_block_height[txid] + hist.append((txid, bh)) + hist.sort(key=lambda x: x[1]) # FIXME put mempool txs last + return hist + + async def _send_notifications(self, *, touched_sh: Iterable[str], height_changed: bool = False) -> None: + if height_changed and self.subbed_headers and self.notified_height != self.cur_height: + self.notified_height = self.cur_height + args = (self._get_headersub_result(),) + await self.send_notification('blockchain.headers.subscribe', args) + touched_sh = set(sh for sh in touched_sh if sh in self.subbed_scripthashes) + for sh in touched_sh: + hist = self._calc_sh_history(sh) + args = (sh, history_status(hist)) + await self.send_notification("blockchain.scripthash.subscribe", args) + + async def mine_block(self, *, txids_mined: Iterable[str] = None): + if txids_mined is None: + txids_mined = [] + self.cur_height += 1 + touched_sh = set() + for txid in txids_mined: + self.txid_to_block_height[txid] = self.cur_height + touched_sh |= await self._process_added_tx(txid=txid) + await self._send_notifications(touched_sh=touched_sh, height_changed=True) class TestInterface(ElectrumTestCase): @@ -198,6 +308,7 @@ class TestInterface(ElectrumTestCase): def setUp(self): super().setUp() self.config = SimpleConfig({'electrum_path': self.electrum_path}) + self.config.NETWORK_SKIPMERKLECHECK = True self._orig_WAIT_FOR_BUFFER_GROWTH_SECONDS = PaddedRSTransport.WAIT_FOR_BUFFER_GROWTH_SECONDS PaddedRSTransport.WAIT_FOR_BUFFER_GROWTH_SECONDS = 0 @@ -207,7 +318,7 @@ class TestInterface(ElectrumTestCase): async def asyncSetUp(self): await super().asyncSetUp() - self._server: asyncio.base_events.Server = await aiorpcx.serve_rs(ServerSession, "127.0.0.1") + self._server: asyncio.base_events.Server = await aiorpcx.serve_rs(ToyServerSession, "127.0.0.1") server_socket_addr = self._server.sockets[0].getsockname() self._server_port = server_socket_addr[1] self.network = MockNetwork(config=self.config) @@ -255,3 +366,35 @@ class TestInterface(ElectrumTestCase): rawtx2 = await interface.get_transaction(tx.txid()) self.assertEqual(rawtx1, rawtx2) self.assertEqual(_get_active_server_session()._method_counts["blockchain.transaction.get"], 0) + + async def test_dont_request_gethistory_if_status_change_results_from_mempool_txs_simply_getting_mined(self): + """After a new block is mined, we recv "blockchain.scripthash.subscribe" notifs. + We opportunistically guess the scripthash status changed purely because touching mempool txs just got mined. + If the guess is correct, we won't call the "blockchain.scripthash.get_history" RPC. + """ + interface = await self._start_iface_and_wait_for_sync() + w1 = restore_wallet_from_text__for_unittest("9dk", path=None, config=self.config)['wallet'] # type: Abstract_Wallet + w1.start_network(self.network) + await w1.up_to_date_changed_event.wait() + self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 0) + # fund w1 (in mempool) + funding_tx = "01000000000101e855888b77b1688d08985b863bfe85b354049b4eba923db9b5cf37089975d5d10000000000fdffffff0280969800000000001600140297bde2689a3c79ffe050583b62f86f2d9dae5460abe9000000000016001472df47551b6e7e0c8428814d2e572bc5ac773dda024730440220383efa2f0f5b87f8ce5d6b6eaf48cba03bf522b23fbb23b2ac54ff9d9a8f6a8802206f67d1f909f3c7a22ac0308ac4c19853ffca3a9317e1d7e0c88cc3a86853aaac0121035061949222555a0df490978fe6e7ebbaa96332ecb5c266918fd800c0eef736e7358d1400" + funding_txid = await _get_active_server_session()._handle_transaction_broadcast(funding_tx) + await w1.up_to_date_changed_event.wait() + while not w1.is_up_to_date(): + await w1.up_to_date_changed_event.wait() + self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 1) + self.assertEqual( + w1.adb.get_address_history("bcrt1qq2tmmcngng78nllq2pvrkchcdukemtj5jnxz44"), + {funding_txid: 0}) + # mine funding tx + await _get_active_server_session().mine_block(txids_mined=[funding_txid]) + await w1.up_to_date_changed_event.wait() + while not w1.is_up_to_date(): + await w1.up_to_date_changed_event.wait() + # see if we managed to guess new history, and hence did not need to call get_history RPC + self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 1) + self.assertEqual( + w1.adb.get_address_history("bcrt1qq2tmmcngng78nllq2pvrkchcdukemtj5jnxz44"), + {funding_txid: 7}) +