network: replace "server" strings with ServerAddr objects
This commit is contained in:
@@ -29,7 +29,7 @@ import sys
|
||||
import traceback
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
|
||||
from collections import defaultdict
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
|
||||
import itertools
|
||||
@@ -198,22 +198,57 @@ class _RSClient(RSClient):
|
||||
raise ConnectError(e) from e
|
||||
|
||||
|
||||
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
|
||||
# host might be IPv6 address, hence do rsplit:
|
||||
host, port, protocol = str(server_str).rsplit(':', 2)
|
||||
if not host:
|
||||
raise ValueError('host must not be empty')
|
||||
if host[0] == '[' and host[-1] == ']': # IPv6
|
||||
host = host[1:-1]
|
||||
if protocol not in ('s', 't'):
|
||||
raise ValueError('invalid network protocol: {}'.format(protocol))
|
||||
net_addr = NetAddress(host, port) # this validates host and port
|
||||
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
||||
return host, port, protocol
|
||||
class ServerAddr:
|
||||
|
||||
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
|
||||
assert isinstance(host, str), repr(host)
|
||||
if protocol is None:
|
||||
protocol = 's'
|
||||
if not host:
|
||||
raise ValueError('host must not be empty')
|
||||
if host[0] == '[' and host[-1] == ']': # IPv6
|
||||
host = host[1:-1]
|
||||
try:
|
||||
net_addr = NetAddress(host, port) # this validates host and port
|
||||
except Exception as e:
|
||||
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
|
||||
if protocol not in ('s', 't'):
|
||||
raise ValueError(f"invalid network protocol: {protocol}")
|
||||
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
||||
self.port = int(net_addr.port)
|
||||
self.protocol = protocol
|
||||
self._net_addr_str = str(net_addr)
|
||||
|
||||
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
|
||||
return str(':'.join([host, str(port), protocol]))
|
||||
@classmethod
|
||||
def from_str(cls, s: str) -> 'ServerAddr':
|
||||
# host might be IPv6 address, hence do rsplit:
|
||||
host, port, protocol = str(s).rsplit(':', 2)
|
||||
return ServerAddr(host=host, port=port, protocol=protocol)
|
||||
|
||||
def __str__(self):
|
||||
return '{}:{}'.format(self.net_addr_str(), self.protocol)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
|
||||
|
||||
def net_addr_str(self) -> str:
|
||||
return self._net_addr_str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ServerAddr):
|
||||
return False
|
||||
return (self.host == other.host
|
||||
and self.port == other.port
|
||||
and self.protocol == other.protocol)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.host, self.port, self.protocol))
|
||||
|
||||
|
||||
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
|
||||
@@ -232,12 +267,10 @@ class Interface(Logger):
|
||||
|
||||
LOGGING_SHORTCUT = 'i'
|
||||
|
||||
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
|
||||
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
|
||||
self.ready = asyncio.Future()
|
||||
self.got_disconnected = asyncio.Future()
|
||||
self.server = server
|
||||
self.host, self.port, self.protocol = deserialize_server(self.server)
|
||||
self.port = int(self.port)
|
||||
Logger.__init__(self)
|
||||
assert network.config.path
|
||||
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
|
||||
@@ -259,8 +292,20 @@ class Interface(Logger):
|
||||
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
|
||||
self.taskgroup = SilentTaskGroup()
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self.server.host
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self.server.port
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
return self.server.protocol
|
||||
|
||||
def diagnostic_name(self):
|
||||
return str(NetAddress(self.host, self.port))
|
||||
return self.server.net_addr_str()
|
||||
|
||||
def __str__(self):
|
||||
return f"<Interface {self.diagnostic_name()}>"
|
||||
|
||||
Reference in New Issue
Block a user