1
0

Merge pull request #9833 from f321x/use_asyncio_dnspython_methods

dns: use async dnspython interface
This commit is contained in:
ThomasV
2025-05-20 08:56:47 +02:00
committed by GitHub
10 changed files with 68 additions and 70 deletions

View File

@@ -854,10 +854,10 @@ class Commands(Logger):
"bad_keys": len(bad_inputs),
}
def _resolver(self, x, wallet: Abstract_Wallet):
async def _resolver(self, x, wallet: Abstract_Wallet):
if x is None:
return None
out = wallet.contacts.resolve(x)
out = await wallet.contacts.resolve(x)
if out.get('type') == 'openalias' and self.nocheck is False and out.get('validated') is False:
raise UserFacingException(f"cannot verify alias: {x}")
return out['address']
@@ -980,11 +980,13 @@ class Commands(Logger):
fee_policy = self._get_fee_policy(fee, feerate)
domain_addr = from_addr.split(',') if from_addr else None
domain_coins = from_coins.split(',') if from_coins else None
change_addr = self._resolver(change_addr, wallet)
domain_addr = None if domain_addr is None else map(self._resolver, domain_addr, repeat(wallet))
change_addr = await self._resolver(change_addr, wallet)
if domain_addr is not None:
resolvers = [self._resolver(addr, wallet) for addr in domain_addr]
domain_addr = await asyncio.gather(*resolvers)
final_outputs = []
for address, amount in outputs:
address = self._resolver(address, wallet)
address = await self._resolver(address, wallet)
amount_sat = satoshis_or_max(amount)
final_outputs.append(PartialTxOutput.from_address_and_value(address, amount_sat))
coins = wallet.get_spendable_coins(domain_addr)
@@ -1128,7 +1130,7 @@ class Commands(Logger):
arg:str:key:the alias to be retrieved
"""
return wallet.contacts.resolve(key)
return await wallet.contacts.resolve(key)
@command('w')
async def searchcontacts(self, query, wallet: Abstract_Wallet = None):

View File

@@ -22,16 +22,15 @@
# SOFTWARE.
import re
from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING
import asyncio
import dns
import threading
from dns.exception import DNSException
from . import bitcoin
from . import dnssec
from .util import read_json_file, write_json_file, to_string, is_valid_email
from .logging import Logger, get_logger
from .util import trigger_callback
from .util import trigger_callback, get_asyncio_loop
if TYPE_CHECKING:
from .wallet_db import WalletDB
@@ -85,7 +84,7 @@ class Contacts(dict, Logger):
return res
return None
def resolve(self, k):
async def resolve(self, k) -> dict:
if bitcoin.is_address(k):
return {
'address': k,
@@ -99,13 +98,13 @@ class Contacts(dict, Logger):
'address': address,
'type': 'contact'
}
if openalias := self.resolve_openalias(k):
if openalias := await self.resolve_openalias(k):
return openalias
raise AliasNotFoundException("Invalid Bitcoin address or alias", k)
@classmethod
def resolve_openalias(cls, url: str) -> Dict[str, Any]:
out = cls._resolve_openalias(url)
async def resolve_openalias(cls, url: str) -> Dict[str, Any]:
out = await cls._resolve_openalias(url)
if out:
address, name, validated = out
return {
@@ -132,19 +131,17 @@ class Contacts(dict, Logger):
alias = config.OPENALIAS_ID
if alias:
alias = str(alias)
def f():
self.alias_info = self._resolve_openalias(alias)
async def f():
self.alias_info = await self._resolve_openalias(alias)
trigger_callback('alias_received')
t = threading.Thread(target=f)
t.daemon = True
t.start()
asyncio.run_coroutine_threadsafe(f(), get_asyncio_loop())
@classmethod
def _resolve_openalias(cls, url: str) -> Optional[Tuple[str, str, bool]]:
async def _resolve_openalias(cls, url: str) -> Optional[Tuple[str, str, bool]]:
# support email-style addresses, per the OA standard
url = url.replace('@', '.')
try:
records, validated = dnssec.query(url, dns.rdatatype.TXT)
records, validated = await dnssec.query(url, dns.rdatatype.TXT)
except DNSException as e:
_logger.info(f'Error resolving openalias: {repr(e)}')
return None
@@ -161,7 +158,6 @@ class Contacts(dict, Logger):
if not address:
continue
return address, name, validated
return None
return None
@staticmethod

View File

@@ -7,18 +7,16 @@ import socket
import concurrent
from concurrent import futures
import ipaddress
from typing import Optional
import asyncio
import dns
import dns.resolver
import dns.asyncresolver
from .logging import get_logger
from .util import get_asyncio_loop
_logger = get_logger(__name__)
_dns_threads_executor = None # type: Optional[concurrent.futures.Executor]
def configure_dns_resolver() -> None:
# Store this somewhere so we can un-monkey-patch:
@@ -38,16 +36,11 @@ def configure_dns_resolver() -> None:
def _prepare_windows_dns_hack():
# enable dns cache
resolver = dns.resolver.get_default_resolver()
resolver = dns.asyncresolver.get_default_resolver()
if resolver.cache is None:
resolver.cache = dns.resolver.Cache()
# ensure overall timeout for requests is long enough
resolver.lifetime = max(resolver.lifetime or 1, 30.0)
# prepare threads
global _dns_threads_executor
if _dns_threads_executor is None:
_dns_threads_executor = concurrent.futures.ThreadPoolExecutor(max_workers=20,
thread_name_prefix='dns_resolver')
def _is_force_system_dns_for_host(host: str) -> bool:
@@ -69,8 +62,15 @@ def _fast_getaddrinfo(host, *args, **kwargs):
addrs = []
expected_errors = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer,
concurrent.futures.CancelledError, concurrent.futures.TimeoutError)
ipv6_fut = _dns_threads_executor.submit(dns.resolver.resolve, host, dns.rdatatype.AAAA)
ipv4_fut = _dns_threads_executor.submit(dns.resolver.resolve, host, dns.rdatatype.A)
loop = get_asyncio_loop()
ipv6_fut = asyncio.run_coroutine_threadsafe(
dns.asyncresolver.resolve(host, dns.rdatatype.AAAA),
loop,
)
ipv4_fut = asyncio.run_coroutine_threadsafe(
dns.asyncresolver.resolve(host, dns.rdatatype.A),
loop,
)
# try IPv6
try:
answers = ipv6_fut.result()

View File

@@ -33,10 +33,10 @@
import dns
import dns.name
import dns.query
import dns.asyncquery
import dns.dnssec
import dns.message
import dns.resolver
import dns.asyncresolver
import dns.rdatatype
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.CNAME
@@ -53,6 +53,7 @@ import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
from .logging import get_logger
from typing import Tuple
_logger = get_logger(__name__)
@@ -67,9 +68,9 @@ trust_anchors = [
]
def _check_query(ns, sub, _type, keys):
async def _check_query(ns, sub, _type, keys) -> dns.rrset.RRset:
q = dns.message.make_query(sub, _type, want_dnssec=True)
response = dns.query.tcp(q, ns, timeout=5)
response = await dns.asyncquery.tcp(q, ns, timeout=5)
assert response.rcode() == 0, 'No answer'
answer = response.answer
assert len(answer) != 0, ('No DNS record found', sub, _type)
@@ -86,13 +87,13 @@ def _check_query(ns, sub, _type, keys):
return rrset
def _get_and_validate(ns, url, _type):
async def _get_and_validate(ns, url, _type) -> dns.rrset.RRset:
# get trusted root key
root_rrset = None
for dnskey_rr in trust_anchors:
try:
# Check if there is a valid signature for the root dnskey
root_rrset = _check_query(ns, '', dns.rdatatype.DNSKEY, {dns.name.root: dnskey_rr})
root_rrset = await _check_query(ns, '', dns.rdatatype.DNSKEY, {dns.name.root: dnskey_rr})
break
except dns.dnssec.ValidationFailure:
# It's OK as long as one key validates
@@ -107,16 +108,16 @@ def _get_and_validate(ns, url, _type):
name = dns.name.from_text(sub)
# If server is authoritative, don't fetch DNSKEY
query = dns.message.make_query(sub, dns.rdatatype.NS)
response = dns.query.udp(query, ns, 3)
response = await dns.asyncquery.udp(query, ns, 3)
assert response.rcode() == dns.rcode.NOERROR, "query error"
rrset = response.authority[0] if len(response.authority) > 0 else response.answer[0]
rr = rrset[0]
if rr.rdtype == dns.rdatatype.SOA:
continue
# get DNSKEY (self-signed)
rrset = _check_query(ns, sub, dns.rdatatype.DNSKEY, None)
rrset = await _check_query(ns, sub, dns.rdatatype.DNSKEY, None)
# get DS (signed by parent)
ds_rrset = _check_query(ns, sub, dns.rdatatype.DS, keys)
ds_rrset = await _check_query(ns, sub, dns.rdatatype.DS, keys)
# verify that a signed DS validates DNSKEY
for ds in ds_rrset:
for dnskey in rrset:
@@ -132,20 +133,20 @@ def _get_and_validate(ns, url, _type):
# set key for next iteration
keys = {name: rrset}
# get TXT record (signed by zone)
rrset = _check_query(ns, url, _type, keys)
rrset = await _check_query(ns, url, _type, keys)
return rrset
def query(url, rtype):
async def query(url, rtype) -> Tuple[dns.rrset.RRset, bool]:
# FIXME this method is not using the network proxy. (although the proxy might not support UDP?)
# 8.8.8.8 is Google's public DNS server
nameservers = ['8.8.8.8']
ns = nameservers[0]
try:
out = _get_and_validate(ns, url, rtype)
out = await _get_and_validate(ns, url, rtype)
validated = True
except Exception as e:
_logger.info(f"DNSSEC error: {repr(e)}")
out = dns.resolver.resolve(url, rtype)
out = await dns.asyncresolver.resolve(url, rtype)
validated = False
return out, validated

View File

@@ -17,6 +17,7 @@ from electrum.lnutil import format_short_channel_id
from electrum.bitcoin import COIN, address_to_script
from electrum.paymentrequest import PaymentRequest
from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierState, PaymentIdentifierType
from electrum.network import Network
from .qetypes import QEAmount
from .qewallet import QEWallet
@@ -523,7 +524,7 @@ class QEInvoiceParser(QEInvoice):
def _bip70_payment_request_resolved(self, pr: 'PaymentRequest'):
self._logger.debug('resolved payment request')
if pr.verify():
if Network.run_from_another_thread(pr.verify()):
invoice = Invoice.from_bip70_payreq(pr, height=0)
if self._wallet.wallet.get_invoice_status(invoice) == PR_PAID:
self.validationError.emit('unknown', _('Invoice already paid'))

View File

@@ -1646,7 +1646,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger, QtEventListener):
grid.addWidget(QLabel(format_time(invoice.exp + invoice.time)), 4, 1)
if invoice.bip70:
pr = paymentrequest.PaymentRequest(bytes.fromhex(invoice.bip70))
pr.verify()
Network.run_from_another_thread(pr.verify())
grid.addWidget(QLabel(_("Requestor") + ':'), 5, 0)
grid.addWidget(QLabel(pr.get_requestor()), 5, 1)
grid.addWidget(QLabel(_("Signature") + ':'), 6, 0)

View File

@@ -22,7 +22,7 @@ import urllib.parse
import itertools
import aiohttp
import dns.resolver
import dns.asyncresolver
import dns.exception
from aiorpcx import run_in_thread, NetAddress, ignore_after
@@ -427,10 +427,9 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return [random.choice(fallback_list)]
# last resort: try dns seeds (BOLT-10)
return await run_in_thread(self._get_peers_from_dns_seeds)
return await self._get_peers_from_dns_seeds()
def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
# NOTE: potentially long blocking call, do not run directly on asyncio event loop.
async def _get_peers_from_dns_seeds(self) -> Sequence[LNPeerAddr]:
# Return several peers to reduce the number of dns queries.
if not constants.net.LN_DNS_SEEDS:
return []
@@ -439,7 +438,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
try:
# note: this might block for several seconds
# this will include bech32-encoded-pubkeys and ports
srv_answers = resolve_dns_srv('r{}.{}'.format(
srv_answers = await resolve_dns_srv('r{}.{}'.format(
constants.net.LN_REALM_BYTE, dns_seed))
except dns.exception.DNSException as e:
self.logger.info(f'failed querying (1) dns seed "{dns_seed}" for ln peers: {repr(e)}')
@@ -451,8 +450,8 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peers = []
for srv_ans in srv_answers:
try:
# note: this might block for several seconds
answers = dns.resolver.resolve(srv_ans['host'])
# note: this might take several seconds
answers = await dns.asyncresolver.resolve(srv_ans['host'])
except dns.exception.DNSException as e:
self.logger.info(f'failed querying (2) dns seed "{dns_seed}" for ln peers: {repr(e)}')
continue

View File

@@ -349,7 +349,7 @@ class PaymentIdentifier(Logger):
self.set_state(PaymentIdentifierState.NOT_FOUND)
elif self.bip70:
pr = await paymentrequest.get_payment_request(self.bip70)
if pr.verify():
if await pr.verify():
self.bip70_data = pr
self.set_state(PaymentIdentifierState.MERCHANT_NOTIFY)
else:
@@ -653,7 +653,7 @@ class PaymentIdentifier(Logger):
if parts and len(parts) > 0 and bitcoin.is_address(parts[0]):
return None
try:
data = self.contacts.resolve(key) # TODO: don't use contacts as delegate to resolve openalias, separate.
data = await self.contacts.resolve(key) # TODO: don't use contacts as delegate to resolve openalias, separate.
return data
except AliasNotFoundException as e:
self.logger.info(f'OpenAlias not found: {repr(e)}')

View File

@@ -40,7 +40,8 @@ except ImportError:
sys.exit("Error: could not find paymentrequest_pb2.py. Create it with 'contrib/generate_payreqpb2.sh'")
from . import bitcoin, constants, util, transaction, x509, rsakey
from .util import bfh, make_aiohttp_session, error_text_bytes_to_safe_str, get_running_loop
from .util import (bfh, make_aiohttp_session, error_text_bytes_to_safe_str, get_running_loop,
get_asyncio_loop)
from .invoices import Invoice, get_id_from_onchain_outputs
from .bitcoin import address_to_script
from .transaction import PartialTxOutput
@@ -104,10 +105,8 @@ async def get_payment_request(url: str) -> 'PaymentRequest':
data = None
error = f"Unknown scheme for payment request. URL: {url}"
pr = PaymentRequest(data, error=error)
loop = get_running_loop()
# do x509/dnssec verification now (in separate thread, to avoid blocking event loop).
# we still expect the caller to at least check pr.error!
await loop.run_in_executor(None, pr.verify)
# do x509/dnssec verification now. we still expect the caller to at least check pr.error!
await pr.verify()
return pr
@@ -153,7 +152,7 @@ class PaymentRequest:
self.memo = self.details.memo
self.payment_url = self.details.payment_url
def verify(self) -> bool:
async def verify(self) -> bool:
# FIXME: we should enforce that this method was called before we attempt payment
# note: this method might do network requests (at least for verify_dnssec)
if self._verified_success is True:
@@ -176,7 +175,7 @@ class PaymentRequest:
if pr.pki_type in ["x509+sha256", "x509+sha1"]:
return self.verify_x509(pr)
elif pr.pki_type in ["dnssec+btc", "dnssec+ecdsa"]:
return self.verify_dnssec(pr)
return await self.verify_dnssec(pr)
else:
self.error = "ERROR: Unsupported PKI Type for Message Signature"
return False
@@ -222,10 +221,10 @@ class PaymentRequest:
self._verified_success = True
return True
def verify_dnssec(self, pr):
async def verify_dnssec(self, pr):
sig = pr.signature
alias = pr.pki_data
info = Contacts.resolve_openalias(alias)
info: dict = await Contacts.resolve_openalias(alias)
if info.get('validated') is not True:
self.error = "Alias verification failed (DNSSEC)"
return False

View File

@@ -60,7 +60,7 @@ import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
import aiorpcx
import certifi
import dns.resolver
import dns.asyncresolver
from .i18n import _
from .logging import get_logger, Logger
@@ -1851,9 +1851,9 @@ def list_enabled_bits(x: int) -> Sequence[int]:
return tuple(i for i, b in enumerate(rev_bin) if b == '1')
def resolve_dns_srv(host: str):
async def resolve_dns_srv(host: str):
# FIXME this method is not using the network proxy. (although the proxy might not support UDP?)
srv_records = dns.resolver.resolve(host, 'SRV')
srv_records = await dns.asyncresolver.resolve(host, 'SRV')
# priority: prefer lower
# weight: tie breaker; prefer higher
srv_records = sorted(srv_records, key=lambda x: (x.priority, -x.weight))