1
0

network: replace "server" strings with ServerAddr objects

This commit is contained in:
SomberNight
2020-04-14 16:56:17 +02:00
parent ef2ff11926
commit cf1f2ba4dc
6 changed files with 172 additions and 103 deletions

View File

@@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
import itertools
@@ -198,22 +198,57 @@ class _RSClient(RSClient):
raise ConnectError(e) from e
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(server_str).rsplit(':', 2)
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
if protocol not in ('s', 't'):
raise ValueError('invalid network protocol: {}'.format(protocol))
net_addr = NetAddress(host, port) # this validates host and port
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
return host, port, protocol
class ServerAddr:
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
assert isinstance(host, str), repr(host)
if protocol is None:
protocol = 's'
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
try:
net_addr = NetAddress(host, port) # this validates host and port
except Exception as e:
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
if protocol not in ('s', 't'):
raise ValueError(f"invalid network protocol: {protocol}")
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
self.port = int(net_addr.port)
self.protocol = protocol
self._net_addr_str = str(net_addr)
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
return str(':'.join([host, str(port), protocol]))
@classmethod
def from_str(cls, s: str) -> 'ServerAddr':
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(s).rsplit(':', 2)
return ServerAddr(host=host, port=port, protocol=protocol)
def __str__(self):
return '{}:{}'.format(self.net_addr_str(), self.protocol)
def to_json(self) -> str:
return str(self)
def __repr__(self):
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
def net_addr_str(self) -> str:
return self._net_addr_str
def __eq__(self, other):
if not isinstance(other, ServerAddr):
return False
return (self.host == other.host
and self.port == other.port
and self.protocol == other.protocol)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
@@ -232,12 +267,10 @@ class Interface(Logger):
LOGGING_SHORTCUT = 'i'
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
self.ready = asyncio.Future()
self.got_disconnected = asyncio.Future()
self.server = server
self.host, self.port, self.protocol = deserialize_server(self.server)
self.port = int(self.port)
Logger.__init__(self)
assert network.config.path
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
@@ -259,8 +292,20 @@ class Interface(Logger):
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.taskgroup = SilentTaskGroup()
@property
def host(self):
return self.server.host
@property
def port(self):
return self.server.port
@property
def protocol(self):
return self.server.protocol
def diagnostic_name(self):
return str(NetAddress(self.host, self.port))
return self.server.net_addr_str()
def __str__(self):
return f"<Interface {self.diagnostic_name()}>"