From 64a160027ad19b5caf3716e96027695332fc1a0e Mon Sep 17 00:00:00 2001 From: Sander van Grieken Date: Wed, 23 Apr 2025 16:09:31 +0200 Subject: [PATCH] imports, whitespace, type hints --- electrum/mnemonic.py | 2 +- electrum/mpp_split.py | 1 - electrum/network.py | 31 +++++++++------------- electrum/trampoline.py | 20 +++++++++----- electrum/transaction.py | 59 ++++++++++++++++++++++++----------------- electrum/x509.py | 2 +- 6 files changed, 63 insertions(+), 52 deletions(-) diff --git a/electrum/mnemonic.py b/electrum/mnemonic.py index 91546c2c9..3752119b8 100644 --- a/electrum/mnemonic.py +++ b/electrum/mnemonic.py @@ -22,7 +22,7 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os + import math import hashlib import unicodedata diff --git a/electrum/mpp_split.py b/electrum/mpp_split.py index 520d7f0a1..084ad612a 100644 --- a/electrum/mpp_split.py +++ b/electrum/mpp_split.py @@ -1,7 +1,6 @@ import random import math from typing import List, Tuple, Dict, NamedTuple -from collections import defaultdict from .lnutil import NoPathFound diff --git a/electrum/network.py b/electrum/network.py index d871bc42b..4dae51352 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -22,22 +22,16 @@ # SOFTWARE. import asyncio import time -import queue import os import random import re from collections import defaultdict import threading -import socket import json -import sys from typing import ( NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set, Any, TypeVar, Callable ) -import traceback -import concurrent -from concurrent import futures import copy import functools from enum import IntEnum @@ -48,20 +42,20 @@ from aiorpcx import ignore_after, NetAddress from aiohttp import ClientResponse from . import util -from .util import (log_exceptions, ignore_exceptions, OldTaskGroup, - bfh, make_aiohttp_session, send_exception_to_crash_reporter, - is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager, - error_text_str_to_safe_str, detect_tor_socks_proxy) -from .bitcoin import COIN, DummyAddress, DummyAddressUsedInTxException +from .util import ( + log_exceptions, ignore_exceptions, OldTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter, MyEncoder, + NetworkRetryManager, error_text_str_to_safe_str, detect_tor_socks_proxy +) +from .bitcoin import DummyAddress, DummyAddressUsedInTxException from . import constants from . import blockchain -from . import bitcoin from . import dns_hacks from .transaction import Transaction -from .blockchain import Blockchain, HEADER_SIZE -from .interface import (Interface, PREFERRED_NETWORK_PROTOCOL, - RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS, - NetworkException, RequestCorrupted, ServerAddr) +from .blockchain import Blockchain +from .interface import ( + Interface, PREFERRED_NETWORK_PROTOCOL, RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS, + NetworkException, RequestCorrupted, ServerAddr +) from .version import PROTOCOL_VERSION from .i18n import _ from .logging import get_logger, Logger @@ -70,11 +64,9 @@ from .fee_policy import FeeHistogram, FeeTimeEstimates, FEE_ETA_TARGETS if TYPE_CHECKING: from collections.abc import Coroutine - from .channel_db import ChannelDB from .lnrouter import LNPathFinder from .lnworker import LNGossip - #from .lnwatcher import WatchTower from .daemon import Daemon from .simple_config import SimpleConfig @@ -554,8 +546,10 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): async def get_banner(): self.banner = await interface.get_server_banner() util.trigger_callback('banner', self.banner) + async def get_donation_address(): self.donation_address = await interface.get_donation_address() + async def get_server_peers(): server_peers = await session.send_request('server.peers.subscribe') random.shuffle(server_peers) @@ -564,6 +558,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): # note that 'parse_servers' also validates the data (which is untrusted input!) self.server_peers = parse_servers(server_peers) util.trigger_callback('servers', self.get_servers()) + async def get_relay_fee(): self.relay_fee = await interface.get_relay_fee() diff --git a/electrum/trampoline.py b/electrum/trampoline.py index 404d32446..32a5f4a06 100644 --- a/electrum/trampoline.py +++ b/electrum/trampoline.py @@ -1,13 +1,13 @@ import io import os import random -from typing import Mapping, DefaultDict, Tuple, Optional, Dict, List, Iterable, Sequence, Set, Any, \ - MutableSequence +from typing import Mapping, Tuple, Optional, List, Iterable, Sequence, Set, Any from .lnutil import LnFeatures, PaymentFeeBudget, FeeBudgetExceeded -from .lnonion import calc_hops_data_for_payment, new_onion_packet, OnionPacket, \ - TRAMPOLINE_HOPS_DATA_SIZE, PER_HOP_HMAC_SIZE -from .lnrouter import RouteEdge, TrampolineEdge, LNPaymentRoute, is_route_within_budget, LNPaymentTRoute +from .lnonion import ( + calc_hops_data_for_payment, new_onion_packet, OnionPacket, TRAMPOLINE_HOPS_DATA_SIZE, PER_HOP_HMAC_SIZE +) +from .lnrouter import TrampolineEdge, is_route_within_budget, LNPaymentTRoute from .lnutil import NoPathFound from .lntransport import LNPeerAddr from . import constants @@ -38,6 +38,7 @@ TRAMPOLINE_NODES_SIGNET = { _TRAMPOLINE_NODES_UNITTESTS = {} # used in unit tests + def hardcoded_trampoline_nodes() -> Mapping[str, LNPeerAddr]: if _TRAMPOLINE_NODES_UNITTESTS: return _TRAMPOLINE_NODES_UNITTESTS @@ -52,12 +53,15 @@ def hardcoded_trampoline_nodes() -> Mapping[str, LNPeerAddr]: else: return {} + def trampolines_by_id(): return dict([(x.pubkey, x) for x in hardcoded_trampoline_nodes().values()]) + def is_hardcoded_trampoline(node_id: bytes) -> bool: return node_id in trampolines_by_id() + def encode_routing_info(r_tags: Sequence[Sequence[Sequence[Any]]]) -> List[bytes]: routes = [] for route in r_tags: @@ -126,6 +130,8 @@ def is_legacy_relay(invoice_features, r_tags) -> Tuple[bool, Set[bytes]]: PLACEHOLDER_FEE = None + + def _extend_trampoline_route( route: List[TrampolineEdge], *, @@ -301,7 +307,7 @@ def create_trampoline_onion( payload.pop('short_channel_id') next_edge = route[i+1] assert next_edge.is_trampoline() - hops_data[i].payload["outgoing_node_id"] = {"outgoing_node_id":next_edge.node_id} + hops_data[i].payload["outgoing_node_id"] = {"outgoing_node_id": next_edge.node_id} # only for final if i == num_hops - 1: payload["payment_data"] = { @@ -310,7 +316,7 @@ def create_trampoline_onion( } # legacy if i == num_hops - 2 and route_edge.invoice_features: - payload["invoice_features"] = {"invoice_features":route_edge.invoice_features} + payload["invoice_features"] = {"invoice_features": route_edge.invoice_features} routing_info_payload_index = i payload["payment_data"] = { "payment_secret": payment_secret, diff --git a/electrum/transaction.py b/electrum/transaction.py index f4959a0ab..454d59ad8 100644 --- a/electrum/transaction.py +++ b/electrum/transaction.py @@ -23,17 +23,14 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - - # Note: The deserialization code originally comes from ABE. import struct -import traceback -import sys import io import base64 -from typing import (Sequence, Union, NamedTuple, Tuple, Optional, Iterable, - Callable, List, Dict, Set, TYPE_CHECKING, Mapping) +from typing import ( + Sequence, Union, NamedTuple, Tuple, Optional, Iterable, Callable, List, Dict, Set, TYPE_CHECKING, Mapping +) from collections import defaultdict from enum import IntEnum import itertools @@ -43,22 +40,18 @@ import copy import electrum_ecc as ecc from electrum_ecc.util import bip340_tagged_hash -from . import bitcoin, constants, segwit_addr, bip32 +from . import bitcoin, bip32 from .bip32 import BIP32Node -from .i18n import _ -from .util import profiler, to_bytes, bfh, chunks, is_hex_str, parse_max_spend -from .bitcoin import (TYPE_ADDRESS, TYPE_SCRIPT, hash_160, - hash160_to_p2sh, hash160_to_p2pkh, hash_to_segwit_addr, - var_int, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC, COIN, - opcodes, base_decode, - base_encode, construct_witness, construct_script, - taproot_tweak_seckey) +from .util import to_bytes, bfh, chunks, is_hex_str, parse_max_spend +from .bitcoin import ( + TYPE_ADDRESS, TYPE_SCRIPT, hash_160, hash160_to_p2sh, hash160_to_p2pkh, hash_to_segwit_addr, var_int, + TOTAL_COIN_SUPPLY_LIMIT_IN_BTC, COIN, opcodes, base_decode, base_encode, construct_witness, construct_script, + taproot_tweak_seckey +) from .crypto import sha256d, sha256 from .logging import get_logger from .util import ShortID, OldTaskGroup -from .bitcoin import DummyAddress from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address -from .json_db import stored_in if TYPE_CHECKING: from .wallet import Abstract_Wallet @@ -659,7 +652,7 @@ class OPPushDataGeneric: class OPGeneric: - def __init__(self, matcher: Callable=None): + def __init__(self, matcher: Callable = None): if matcher is not None: self.matcher = matcher @@ -673,6 +666,7 @@ class OPGeneric: return isinstance(item, cls) \ or (isinstance(item, type) and issubclass(item, cls)) + OPPushDataPubkey = OPPushDataGeneric(lambda x: x in (33, 65)) OP_ANYSEGWIT_VERSION = OPGeneric(lambda x: x in list(range(opcodes.OP_1, opcodes.OP_16 + 1))) @@ -702,6 +696,7 @@ def check_scriptpubkey_template_and_dust(scriptpubkey, amount: Optional[int]): if amount < dust_limit: raise Exception(f'amount ({amount}) is below dust limit for scriptpubkey type ({dust_limit})') + def merge_duplicate_tx_outputs(outputs: Iterable['PartialTxOutput']) -> List['PartialTxOutput']: """Merges outputs that are paying to the same address by replacing them with a single larger output.""" output_dict = {} @@ -713,6 +708,7 @@ def merge_duplicate_tx_outputs(outputs: Iterable['PartialTxOutput']) -> List['Pa output_dict[output.scriptpubkey] = copy.copy(output) return list(output_dict.values()) + def match_script_against_template(script, template, debug=False) -> bool: """Returns whether 'script' matches 'template'.""" if script is None: @@ -744,6 +740,7 @@ def match_script_against_template(script, template, debug=False) -> bool: return False return True + def get_script_type_from_output_script(_bytes: bytes) -> Optional[str]: if _bytes is None: return None @@ -761,6 +758,7 @@ def get_script_type_from_output_script(_bytes: bytes) -> Optional[str]: return 'p2wsh' return None + def get_address_from_output_script(_bytes: bytes, *, net=None) -> Optional[str]: try: decoded = [x for x in script_GetOp(_bytes)] @@ -946,7 +944,7 @@ class Transaction: raise UnknownTxinType("cannot construct witness") @classmethod - def input_script(self, txin: TxInput, *, estimate_size=False) -> bytes: + def input_script(cls, txin: TxInput, *, estimate_size=False) -> bytes: if txin.script_sig is not None: return txin.script_sig if txin.is_coinbase_input(): @@ -1100,6 +1098,7 @@ class Transaction: num_tasks_total = 0 has_errored = False has_finished = False + async def add_info_to_txin(txin: TxInput): nonlocal num_tasks_done, has_errored progress_cb(TxinDataFetchProgress(num_tasks_done, num_tasks_total, has_errored, has_finished)) @@ -1113,6 +1112,7 @@ class Transaction: else: has_errored = True progress_cb(TxinDataFetchProgress(num_tasks_done, num_tasks_total, has_errored, has_finished)) + # schedule a network task for each txin try: async with OldTaskGroup() as group: @@ -1177,7 +1177,7 @@ class Transaction: @classmethod def estimated_input_weight(cls, txin: TxInput, is_segwit_tx: bool) -> int: - '''Return an estimate of serialized input weight in weight units.''' + """Return an estimate of serialized input weight in weight units.""" script_sig = cls.input_script(txin, estimate_size=True) input_size = len(txin.serialize_to_network(script_sig=script_sig)) @@ -2049,7 +2049,8 @@ class PartialTransaction(Transaction): if kt == PSBTGlobalType.UNSIGNED_TX: if tx is not None: raise SerializationError(f"duplicate key: {repr(kt)}") - if key: raise SerializationError(f"key for {repr(kt)} must be empty") + if key: + raise SerializationError(f"key for {repr(kt)} must be empty") unsigned_tx = Transaction(val.hex()) for txin in unsigned_tx.inputs(): if txin.script_sig or txin.witness: @@ -2092,7 +2093,8 @@ class PartialTransaction(Transaction): psbt_version = int.from_bytes(val, byteorder='little', signed=False) if psbt_version > 0: raise SerializationError(f"Only PSBTs with version 0 are supported. Found version: {psbt_version}") - if key: raise SerializationError(f"key for {repr(kt)} must be empty") + if key: + raise SerializationError(f"key for {repr(kt)} must be empty") else: full_key = PSBTSection.get_fullkey_from_keytype_and_key(kt, key) if full_key in tx._unknown: @@ -2119,8 +2121,15 @@ class PartialTransaction(Transaction): return tx @classmethod - def from_io(cls, inputs: Sequence[PartialTxInput], outputs: Sequence[PartialTxOutput], *, - locktime: int = None, version: int = None, BIP69_sort: bool = True): + def from_io( + cls, + inputs: Sequence[PartialTxInput], + outputs: Sequence[PartialTxOutput], + *, + locktime: int = None, + version: int = None, + BIP69_sort: bool = True + ) -> 'PartialTransaction': self = cls() self._inputs = list(inputs) self._outputs = list(outputs) @@ -2501,9 +2510,11 @@ class PartialTransaction(Transaction): await self.add_info_from_network(wallet.network) # log warning if PSBT_*_BIP32_DERIVATION fields cannot be filled with full path due to missing info from .keystore import Xpub + def is_ks_missing_info(ks): return (isinstance(ks, Xpub) and (ks.get_root_fingerprint() is None or ks.get_derivation_prefix() is None)) + if any([is_ks_missing_info(ks) for ks in wallet.get_keystores()]): _logger.warning('PSBT was requested to be filled with full bip32 paths but ' 'some keystores lacked either the derivation prefix or the root fingerprint') diff --git a/electrum/x509.py b/electrum/x509.py index 07450bb2b..c7909082c 100644 --- a/electrum/x509.py +++ b/electrum/x509.py @@ -26,7 +26,6 @@ import hashlib import time -from . import util from .util import profiler, timestamp_to_datetime from .logging import get_logger @@ -141,6 +140,7 @@ class ASN1_Node(bytes): raise TypeError('Can only open constructed types.', hex(self[ixs])) return self.get_node(ixf) + @staticmethod def is_child_of(node1, node2): ixs, ixf, ixl = node1 jxs, jxf, jxl = node2