1
0

Use async dnspython methods for openalias/dnssec

This commit is contained in:
f321x
2025-05-15 14:31:00 +02:00
parent 367dde4c9b
commit 61492d361e
7 changed files with 46 additions and 47 deletions

View File

@@ -841,10 +841,10 @@ class Commands(Logger):
"bad_keys": len(bad_inputs), "bad_keys": len(bad_inputs),
} }
def _resolver(self, x, wallet: Abstract_Wallet): async def _resolver(self, x, wallet: Abstract_Wallet):
if x is None: if x is None:
return None return None
out = wallet.contacts.resolve(x) out = await wallet.contacts.resolve(x)
if out.get('type') == 'openalias' and self.nocheck is False and out.get('validated') is False: if out.get('type') == 'openalias' and self.nocheck is False and out.get('validated') is False:
raise UserFacingException(f"cannot verify alias: {x}") raise UserFacingException(f"cannot verify alias: {x}")
return out['address'] return out['address']
@@ -967,11 +967,13 @@ class Commands(Logger):
fee_policy = self._get_fee_policy(fee, feerate) fee_policy = self._get_fee_policy(fee, feerate)
domain_addr = from_addr.split(',') if from_addr else None domain_addr = from_addr.split(',') if from_addr else None
domain_coins = from_coins.split(',') if from_coins else None domain_coins = from_coins.split(',') if from_coins else None
change_addr = self._resolver(change_addr, wallet) change_addr = await self._resolver(change_addr, wallet)
domain_addr = None if domain_addr is None else map(self._resolver, domain_addr, repeat(wallet)) if domain_addr is not None:
resolvers = [self._resolver(addr, wallet) for addr in domain_addr]
domain_addr = await asyncio.gather(*resolvers)
final_outputs = [] final_outputs = []
for address, amount in outputs: for address, amount in outputs:
address = self._resolver(address, wallet) address = await self._resolver(address, wallet)
amount_sat = satoshis_or_max(amount) amount_sat = satoshis_or_max(amount)
final_outputs.append(PartialTxOutput.from_address_and_value(address, amount_sat)) final_outputs.append(PartialTxOutput.from_address_and_value(address, amount_sat))
coins = wallet.get_spendable_coins(domain_addr) coins = wallet.get_spendable_coins(domain_addr)
@@ -1115,7 +1117,7 @@ class Commands(Logger):
arg:str:key:the alias to be retrieved arg:str:key:the alias to be retrieved
""" """
return wallet.contacts.resolve(key) return await wallet.contacts.resolve(key)
@command('w') @command('w')
async def searchcontacts(self, query, wallet: Abstract_Wallet = None): async def searchcontacts(self, query, wallet: Abstract_Wallet = None):

View File

@@ -22,16 +22,15 @@
# SOFTWARE. # SOFTWARE.
import re import re
from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING
import asyncio
import dns import dns
import threading
from dns.exception import DNSException from dns.exception import DNSException
from . import bitcoin from . import bitcoin
from . import dnssec from . import dnssec
from .util import read_json_file, write_json_file, to_string, is_valid_email from .util import read_json_file, write_json_file, to_string, is_valid_email
from .logging import Logger, get_logger from .logging import Logger, get_logger
from .util import trigger_callback from .util import trigger_callback, get_asyncio_loop
if TYPE_CHECKING: if TYPE_CHECKING:
from .wallet_db import WalletDB from .wallet_db import WalletDB
@@ -85,7 +84,7 @@ class Contacts(dict, Logger):
return res return res
return None return None
def resolve(self, k): async def resolve(self, k) -> dict:
if bitcoin.is_address(k): if bitcoin.is_address(k):
return { return {
'address': k, 'address': k,
@@ -99,13 +98,13 @@ class Contacts(dict, Logger):
'address': address, 'address': address,
'type': 'contact' 'type': 'contact'
} }
if openalias := self.resolve_openalias(k): if openalias := await self.resolve_openalias(k):
return openalias return openalias
raise AliasNotFoundException("Invalid Bitcoin address or alias", k) raise AliasNotFoundException("Invalid Bitcoin address or alias", k)
@classmethod @classmethod
def resolve_openalias(cls, url: str) -> Dict[str, Any]: async def resolve_openalias(cls, url: str) -> Dict[str, Any]:
out = cls._resolve_openalias(url) out = await cls._resolve_openalias(url)
if out: if out:
address, name, validated = out address, name, validated = out
return { return {
@@ -132,19 +131,17 @@ class Contacts(dict, Logger):
alias = config.OPENALIAS_ID alias = config.OPENALIAS_ID
if alias: if alias:
alias = str(alias) alias = str(alias)
def f(): async def f():
self.alias_info = self._resolve_openalias(alias) self.alias_info = await self._resolve_openalias(alias)
trigger_callback('alias_received') trigger_callback('alias_received')
t = threading.Thread(target=f) asyncio.run_coroutine_threadsafe(f(), get_asyncio_loop())
t.daemon = True
t.start()
@classmethod @classmethod
def _resolve_openalias(cls, url: str) -> Optional[Tuple[str, str, bool]]: async def _resolve_openalias(cls, url: str) -> Optional[Tuple[str, str, bool]]:
# support email-style addresses, per the OA standard # support email-style addresses, per the OA standard
url = url.replace('@', '.') url = url.replace('@', '.')
try: try:
records, validated = dnssec.query(url, dns.rdatatype.TXT) records, validated = await dnssec.query(url, dns.rdatatype.TXT)
except DNSException as e: except DNSException as e:
_logger.info(f'Error resolving openalias: {repr(e)}') _logger.info(f'Error resolving openalias: {repr(e)}')
return None return None
@@ -161,7 +158,6 @@ class Contacts(dict, Logger):
if not address: if not address:
continue continue
return address, name, validated return address, name, validated
return None
return None return None
@staticmethod @staticmethod

View File

@@ -33,10 +33,10 @@
import dns import dns
import dns.name import dns.name
import dns.query import dns.asyncquery
import dns.dnssec import dns.dnssec
import dns.message import dns.message
import dns.resolver import dns.asyncresolver
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.ANY.NS import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.CNAME import dns.rdtypes.ANY.CNAME
@@ -53,6 +53,7 @@ import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA import dns.rdtypes.IN.AAAA
from .logging import get_logger from .logging import get_logger
from typing import Tuple
_logger = get_logger(__name__) _logger = get_logger(__name__)
@@ -67,9 +68,9 @@ trust_anchors = [
] ]
def _check_query(ns, sub, _type, keys): async def _check_query(ns, sub, _type, keys) -> dns.rrset.RRset:
q = dns.message.make_query(sub, _type, want_dnssec=True) q = dns.message.make_query(sub, _type, want_dnssec=True)
response = dns.query.tcp(q, ns, timeout=5) response = await dns.asyncquery.tcp(q, ns, timeout=5)
assert response.rcode() == 0, 'No answer' assert response.rcode() == 0, 'No answer'
answer = response.answer answer = response.answer
assert len(answer) != 0, ('No DNS record found', sub, _type) assert len(answer) != 0, ('No DNS record found', sub, _type)
@@ -86,13 +87,13 @@ def _check_query(ns, sub, _type, keys):
return rrset return rrset
def _get_and_validate(ns, url, _type): async def _get_and_validate(ns, url, _type) -> dns.rrset.RRset:
# get trusted root key # get trusted root key
root_rrset = None root_rrset = None
for dnskey_rr in trust_anchors: for dnskey_rr in trust_anchors:
try: try:
# Check if there is a valid signature for the root dnskey # Check if there is a valid signature for the root dnskey
root_rrset = _check_query(ns, '', dns.rdatatype.DNSKEY, {dns.name.root: dnskey_rr}) root_rrset = await _check_query(ns, '', dns.rdatatype.DNSKEY, {dns.name.root: dnskey_rr})
break break
except dns.dnssec.ValidationFailure: except dns.dnssec.ValidationFailure:
# It's OK as long as one key validates # It's OK as long as one key validates
@@ -107,16 +108,16 @@ def _get_and_validate(ns, url, _type):
name = dns.name.from_text(sub) name = dns.name.from_text(sub)
# If server is authoritative, don't fetch DNSKEY # If server is authoritative, don't fetch DNSKEY
query = dns.message.make_query(sub, dns.rdatatype.NS) query = dns.message.make_query(sub, dns.rdatatype.NS)
response = dns.query.udp(query, ns, 3) response = await dns.asyncquery.udp(query, ns, 3)
assert response.rcode() == dns.rcode.NOERROR, "query error" assert response.rcode() == dns.rcode.NOERROR, "query error"
rrset = response.authority[0] if len(response.authority) > 0 else response.answer[0] rrset = response.authority[0] if len(response.authority) > 0 else response.answer[0]
rr = rrset[0] rr = rrset[0]
if rr.rdtype == dns.rdatatype.SOA: if rr.rdtype == dns.rdatatype.SOA:
continue continue
# get DNSKEY (self-signed) # get DNSKEY (self-signed)
rrset = _check_query(ns, sub, dns.rdatatype.DNSKEY, None) rrset = await _check_query(ns, sub, dns.rdatatype.DNSKEY, None)
# get DS (signed by parent) # get DS (signed by parent)
ds_rrset = _check_query(ns, sub, dns.rdatatype.DS, keys) ds_rrset = await _check_query(ns, sub, dns.rdatatype.DS, keys)
# verify that a signed DS validates DNSKEY # verify that a signed DS validates DNSKEY
for ds in ds_rrset: for ds in ds_rrset:
for dnskey in rrset: for dnskey in rrset:
@@ -132,20 +133,20 @@ def _get_and_validate(ns, url, _type):
# set key for next iteration # set key for next iteration
keys = {name: rrset} keys = {name: rrset}
# get TXT record (signed by zone) # get TXT record (signed by zone)
rrset = _check_query(ns, url, _type, keys) rrset = await _check_query(ns, url, _type, keys)
return rrset return rrset
def query(url, rtype): async def query(url, rtype) -> Tuple[dns.rrset.RRset, bool]:
# FIXME this method is not using the network proxy. (although the proxy might not support UDP?) # FIXME this method is not using the network proxy. (although the proxy might not support UDP?)
# 8.8.8.8 is Google's public DNS server # 8.8.8.8 is Google's public DNS server
nameservers = ['8.8.8.8'] nameservers = ['8.8.8.8']
ns = nameservers[0] ns = nameservers[0]
try: try:
out = _get_and_validate(ns, url, rtype) out = await _get_and_validate(ns, url, rtype)
validated = True validated = True
except Exception as e: except Exception as e:
_logger.info(f"DNSSEC error: {repr(e)}") _logger.info(f"DNSSEC error: {repr(e)}")
out = dns.resolver.resolve(url, rtype) out = await dns.asyncresolver.resolve(url, rtype)
validated = False validated = False
return out, validated return out, validated

View File

@@ -17,6 +17,7 @@ from electrum.lnutil import format_short_channel_id
from electrum.bitcoin import COIN, address_to_script from electrum.bitcoin import COIN, address_to_script
from electrum.paymentrequest import PaymentRequest from electrum.paymentrequest import PaymentRequest
from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierState, PaymentIdentifierType from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierState, PaymentIdentifierType
from electrum.network import Network
from .qetypes import QEAmount from .qetypes import QEAmount
from .qewallet import QEWallet from .qewallet import QEWallet
@@ -523,7 +524,7 @@ class QEInvoiceParser(QEInvoice):
def _bip70_payment_request_resolved(self, pr: 'PaymentRequest'): def _bip70_payment_request_resolved(self, pr: 'PaymentRequest'):
self._logger.debug('resolved payment request') self._logger.debug('resolved payment request')
if pr.verify(): if Network.run_from_another_thread(pr.verify()):
invoice = Invoice.from_bip70_payreq(pr, height=0) invoice = Invoice.from_bip70_payreq(pr, height=0)
if self._wallet.wallet.get_invoice_status(invoice) == PR_PAID: if self._wallet.wallet.get_invoice_status(invoice) == PR_PAID:
self.validationError.emit('unknown', _('Invoice already paid')) self.validationError.emit('unknown', _('Invoice already paid'))

View File

@@ -1658,7 +1658,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger, QtEventListener):
grid.addWidget(QLabel(format_time(invoice.exp + invoice.time)), 4, 1) grid.addWidget(QLabel(format_time(invoice.exp + invoice.time)), 4, 1)
if invoice.bip70: if invoice.bip70:
pr = paymentrequest.PaymentRequest(bytes.fromhex(invoice.bip70)) pr = paymentrequest.PaymentRequest(bytes.fromhex(invoice.bip70))
pr.verify() Network.run_from_another_thread(pr.verify())
grid.addWidget(QLabel(_("Requestor") + ':'), 5, 0) grid.addWidget(QLabel(_("Requestor") + ':'), 5, 0)
grid.addWidget(QLabel(pr.get_requestor()), 5, 1) grid.addWidget(QLabel(pr.get_requestor()), 5, 1)
grid.addWidget(QLabel(_("Signature") + ':'), 6, 0) grid.addWidget(QLabel(_("Signature") + ':'), 6, 0)

View File

@@ -349,7 +349,7 @@ class PaymentIdentifier(Logger):
self.set_state(PaymentIdentifierState.NOT_FOUND) self.set_state(PaymentIdentifierState.NOT_FOUND)
elif self.bip70: elif self.bip70:
pr = await paymentrequest.get_payment_request(self.bip70) pr = await paymentrequest.get_payment_request(self.bip70)
if pr.verify(): if await pr.verify():
self.bip70_data = pr self.bip70_data = pr
self.set_state(PaymentIdentifierState.MERCHANT_NOTIFY) self.set_state(PaymentIdentifierState.MERCHANT_NOTIFY)
else: else:
@@ -653,7 +653,7 @@ class PaymentIdentifier(Logger):
if parts and len(parts) > 0 and bitcoin.is_address(parts[0]): if parts and len(parts) > 0 and bitcoin.is_address(parts[0]):
return None return None
try: try:
data = self.contacts.resolve(key) # TODO: don't use contacts as delegate to resolve openalias, separate. data = await self.contacts.resolve(key) # TODO: don't use contacts as delegate to resolve openalias, separate.
return data return data
except AliasNotFoundException as e: except AliasNotFoundException as e:
self.logger.info(f'OpenAlias not found: {repr(e)}') self.logger.info(f'OpenAlias not found: {repr(e)}')

View File

@@ -40,7 +40,8 @@ except ImportError:
sys.exit("Error: could not find paymentrequest_pb2.py. Create it with 'contrib/generate_payreqpb2.sh'") sys.exit("Error: could not find paymentrequest_pb2.py. Create it with 'contrib/generate_payreqpb2.sh'")
from . import bitcoin, constants, util, transaction, x509, rsakey from . import bitcoin, constants, util, transaction, x509, rsakey
from .util import bfh, make_aiohttp_session, error_text_bytes_to_safe_str, get_running_loop from .util import (bfh, make_aiohttp_session, error_text_bytes_to_safe_str, get_running_loop,
get_asyncio_loop)
from .invoices import Invoice, get_id_from_onchain_outputs from .invoices import Invoice, get_id_from_onchain_outputs
from .bitcoin import address_to_script from .bitcoin import address_to_script
from .transaction import PartialTxOutput from .transaction import PartialTxOutput
@@ -104,10 +105,8 @@ async def get_payment_request(url: str) -> 'PaymentRequest':
data = None data = None
error = f"Unknown scheme for payment request. URL: {url}" error = f"Unknown scheme for payment request. URL: {url}"
pr = PaymentRequest(data, error=error) pr = PaymentRequest(data, error=error)
loop = get_running_loop() # do x509/dnssec verification now. we still expect the caller to at least check pr.error!
# do x509/dnssec verification now (in separate thread, to avoid blocking event loop). await pr.verify()
# we still expect the caller to at least check pr.error!
await loop.run_in_executor(None, pr.verify)
return pr return pr
@@ -153,7 +152,7 @@ class PaymentRequest:
self.memo = self.details.memo self.memo = self.details.memo
self.payment_url = self.details.payment_url self.payment_url = self.details.payment_url
def verify(self) -> bool: async def verify(self) -> bool:
# FIXME: we should enforce that this method was called before we attempt payment # FIXME: we should enforce that this method was called before we attempt payment
# note: this method might do network requests (at least for verify_dnssec) # note: this method might do network requests (at least for verify_dnssec)
if self._verified_success is True: if self._verified_success is True:
@@ -176,7 +175,7 @@ class PaymentRequest:
if pr.pki_type in ["x509+sha256", "x509+sha1"]: if pr.pki_type in ["x509+sha256", "x509+sha1"]:
return self.verify_x509(pr) return self.verify_x509(pr)
elif pr.pki_type in ["dnssec+btc", "dnssec+ecdsa"]: elif pr.pki_type in ["dnssec+btc", "dnssec+ecdsa"]:
return self.verify_dnssec(pr) return await self.verify_dnssec(pr)
else: else:
self.error = "ERROR: Unsupported PKI Type for Message Signature" self.error = "ERROR: Unsupported PKI Type for Message Signature"
return False return False
@@ -222,10 +221,10 @@ class PaymentRequest:
self._verified_success = True self._verified_success = True
return True return True
def verify_dnssec(self, pr): async def verify_dnssec(self, pr):
sig = pr.signature sig = pr.signature
alias = pr.pki_data alias = pr.pki_data
info = Contacts.resolve_openalias(alias) info: dict = await Contacts.resolve_openalias(alias)
if info.get('validated') is not True: if info.get('validated') is not True:
self.error = "Alias verification failed (DNSSEC)" self.error = "Alias verification failed (DNSSEC)"
return False return False