diff --git a/electrum/interface.py b/electrum/interface.py index 2411a06b0..36cfb13c4 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -64,6 +64,7 @@ from .i18n import _ from .logging import Logger from .transaction import Transaction from .fee_policy import FEE_ETA_TARGETS +from .lrucache import LRUCache if TYPE_CHECKING: from .network import Network @@ -558,6 +559,7 @@ class Interface(Logger): self.tip = 0 self._headers_cache = {} # type: Dict[int, bytes] + self._rawtx_cache = LRUCache(maxsize=20) # type: LRUCache[str, bytes] # txid->rawtx self.fee_estimates_eta = {} # type: Dict[int, int] @@ -1318,6 +1320,8 @@ class Interface(Logger): async def get_transaction(self, tx_hash: str, *, timeout=None) -> str: if not is_hash256_str(tx_hash): raise Exception(f"{repr(tx_hash)} is not a txid") + if rawtx_bytes := self._rawtx_cache.get(tx_hash): + return rawtx_bytes.hex() raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout) # validate response if not is_hex_str(raw): @@ -1329,16 +1333,21 @@ class Interface(Logger): raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e if tx.txid() != tx_hash: raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})") + self._rawtx_cache[tx_hash] = bytes.fromhex(raw) return raw async def broadcast_transaction(self, tx: 'Transaction', *, timeout=None) -> None: """caller should handle TxBroadcastError and RequestTimedOut""" + txid_calc = tx.txid() + assert txid_calc is not None + rawtx = tx.serialize() + assert is_hex_str(rawtx) if timeout is None: timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent) if any(DummyAddress.is_dummy_address(txout.address) for txout in tx.outputs()): raise DummyAddressUsedInTxException("tried to broadcast tx with dummy address!") try: - out = await self.session.send_request('blockchain.transaction.broadcast', [tx.serialize()], timeout=timeout) + out = await self.session.send_request('blockchain.transaction.broadcast', [rawtx], timeout=timeout) # note: both 'out' and exception messages are untrusted input from the server except (RequestTimedOut, asyncio.CancelledError, asyncio.TimeoutError): raise # pass-through @@ -1349,10 +1358,14 @@ class Interface(Logger): self.logger.info(f"broadcast_transaction error2 [DO NOT TRUST THIS MESSAGE]: {error_text_str_to_safe_str(repr(e))}. tx={str(tx)}") send_exception_to_crash_reporter(e) raise TxBroadcastUnknownError() from e - if out != tx.txid(): + if out != txid_calc: self.logger.info(f"unexpected txid for broadcast_transaction [DO NOT TRUST THIS MESSAGE]: " - f"{error_text_str_to_safe_str(out)} != {tx.txid()}. tx={str(tx)}") + f"{error_text_str_to_safe_str(out)} != {txid_calc}. tx={str(tx)}") raise TxBroadcastHashMismatch(_("Server returned unexpected transaction ID.")) + # broadcast succeeded. + # We now cache the rawtx, for *this interface only*. The tx likely touches some ismine addresses, affecting + # the status of a scripthash we are subscribed to. Caching here will save a future get_transaction RPC. + self._rawtx_cache[txid_calc] = bytes.fromhex(rawtx) async def get_history_for_scripthash(self, sh: str) -> List[dict]: if not is_hash256_str(sh): diff --git a/tests/test_interface.py b/tests/test_interface.py index 47310853d..aa54f9bdd 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,4 +1,5 @@ import asyncio +import collections import aiorpcx from aiorpcx import RPCError @@ -117,6 +118,7 @@ class ServerSession(aiorpcx.RPCSession, Logger): self.txs = { "bdae818ad3c1f261317738ae9284159bf54874356f186dbc7afd631dc1527fcb": bfh("020000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff025100ffffffff0200f2052a010000001600140297bde2689a3c79ffe050583b62f86f2d9dae540000000000000000266a24aa21a9ede2f61c3f71d1defd3fa999dfa36953755c690689799962b48bebd836974e8cf90120000000000000000000000000000000000000000000000000000000000000000000000000"), } # type: dict[str, bytes] + self._method_counts = collections.defaultdict(int) # type: dict[str, int] _active_server_sessions.add(self) async def connection_lost(self): @@ -136,6 +138,7 @@ class ServerSession(aiorpcx.RPCSession, Logger): 'server.ping': self._handle_ping, } handler = handlers.get(request.method) + self._method_counts[request.method] += 1 coro = aiorpcx.handler_invocation(handler, request)() return await coro @@ -220,11 +223,18 @@ class TestInterface(ElectrumTestCase): # try requesting known tx: rawtx = await interface.get_transaction("bdae818ad3c1f261317738ae9284159bf54874356f186dbc7afd631dc1527fcb") self.assertEqual(rawtx, _get_active_server_session().txs["bdae818ad3c1f261317738ae9284159bf54874356f186dbc7afd631dc1527fcb"].hex()) + self.assertEqual(_get_active_server_session()._method_counts["blockchain.transaction.get"], 2) async def test_transaction_broadcast(self): interface = await self._start_iface_and_wait_for_sync() rawtx1 = "020000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff025200ffffffff0200f2052a010000001600140297bde2689a3c79ffe050583b62f86f2d9dae540000000000000000266a24aa21a9ede2f61c3f71d1defd3fa999dfa36953755c690689799962b48bebd836974e8cf90120000000000000000000000000000000000000000000000000000000000000000000000000" tx = Transaction(rawtx1) + # broadcast await interface.broadcast_transaction(tx) + self.assertEqual(bfh(rawtx1), _get_active_server_session().txs.get(tx.txid())) + # now request tx. + # as we just broadcast this same tx, this will hit the client iface cache, and won't call the server. + self.assertEqual(_get_active_server_session()._method_counts["blockchain.transaction.get"], 0) rawtx2 = await interface.get_transaction(tx.txid()) self.assertEqual(rawtx1, rawtx2) + self.assertEqual(_get_active_server_session()._method_counts["blockchain.transaction.get"], 0)